You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/01/19 02:30:01 UTC

spark git commit: [SPARK-12700] [SQL] embed condition into SMJ and BroadcastHashJoin

Repository: spark
Updated Branches:
  refs/heads/master 39ac56fc6 -> 323d51f1d


[SPARK-12700] [SQL] embed condition into SMJ and BroadcastHashJoin

Currently, SortMergeJoin and BroadcastHashJoin do not support a join condition, so they need a following Filter for it. The result projection that generates UnsafeRows can be very expensive when the join produces many rows, most of which are then filtered out by the condition.

This PR adds condition support to SortMergeJoin and BroadcastHashJoin, as the outer joins already have.

This could improve the performance of Q72 by 7x (from 120s to 16.5s).

Author: Davies Liu <da...@databricks.com>

Closes #10653 from davies/filter_join.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/323d51f1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/323d51f1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/323d51f1

Branch: refs/heads/master
Commit: 323d51f1dadf733e413203d678cb3f76e4d68981
Parents: 39ac56f
Author: Davies Liu <da...@databricks.com>
Authored: Mon Jan 18 17:29:54 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Mon Jan 18 17:29:54 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   | 24 ++----
 .../sql/execution/joins/BroadcastHashJoin.scala |  1 +
 .../spark/sql/execution/joins/HashJoin.scala    | 81 ++++++++++++--------
 .../sql/execution/joins/HashOuterJoin.scala     |  5 +-
 .../sql/execution/joins/SortMergeJoin.scala     | 46 +++++++----
 .../sql/execution/joins/InnerJoinSuite.scala    | 11 +--
 6 files changed, 96 insertions(+), 72 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/323d51f1/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 910519d..df0f730 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
 import org.apache.spark.sql.{execution, Strategy}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -77,33 +78,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    */
   object EquiJoinSelection extends Strategy with PredicateHelper {
 
-    private[this] def makeBroadcastHashJoin(
-        leftKeys: Seq[Expression],
-        rightKeys: Seq[Expression],
-        left: LogicalPlan,
-        right: LogicalPlan,
-        condition: Option[Expression],
-        side: joins.BuildSide): Seq[SparkPlan] = {
-      val broadcastHashJoin = execution.joins.BroadcastHashJoin(
-        leftKeys, rightKeys, side, planLater(left), planLater(right))
-      condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
-    }
-
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
 
       // --- Inner joins --------------------------------------------------------------------------
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight)
+        joins.BroadcastHashJoin(
+          leftKeys, rightKeys, BuildRight, condition, planLater(left), planLater(right)) :: Nil
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
+        joins.BroadcastHashJoin(
+          leftKeys, rightKeys, BuildLeft, condition, planLater(left), planLater(right)) :: Nil
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
-        val mergeJoin =
-          joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
-        condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+        joins.SortMergeJoin(
+          leftKeys, rightKeys, condition, planLater(left), planLater(right)) :: Nil
 
       // --- Outer joins --------------------------------------------------------------------------
 

http://git-wip-us.apache.org/repos/asf/spark/blob/323d51f1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 0a818cc..c9ea579 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -39,6 +39,7 @@ case class BroadcastHashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
     buildSide: BuildSide,
+    condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan)
   extends BinaryNode with HashJoin {

http://git-wip-us.apache.org/repos/asf/spark/blob/323d51f1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 7f9d9da..8ef8540 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.joins
 
+import java.util.NoSuchElementException
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
@@ -29,6 +31,7 @@ trait HashJoin {
   val leftKeys: Seq[Expression]
   val rightKeys: Seq[Expression]
   val buildSide: BuildSide
+  val condition: Option[Expression]
   val left: SparkPlan
   val right: SparkPlan
 
@@ -50,6 +53,12 @@ trait HashJoin {
   protected def streamSideKeyGenerator: Projection =
     UnsafeProjection.create(streamedKeys, streamedPlan.output)
 
+  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
+    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+  } else {
+    (r: InternalRow) => true
+  }
+
   protected def hashJoin(
       streamIter: Iterator[InternalRow],
       numStreamRows: LongSQLMetric,
@@ -68,44 +77,52 @@ trait HashJoin {
 
       private[this] val joinKeys = streamSideKeyGenerator
 
-      override final def hasNext: Boolean =
-        (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
-          (streamIter.hasNext && fetchNext())
+      override final def hasNext: Boolean = {
+        while (true) {
+          // check if it's end of current matches
+          if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) {
+            currentHashMatches = null
+            currentMatchPosition = -1
+          }
 
-      override final def next(): InternalRow = {
-        val ret = buildSide match {
-          case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
-          case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
-        }
-        currentMatchPosition += 1
-        numOutputRows += 1
-        resultProjection(ret)
-      }
+          // find the next match
+          while (currentHashMatches == null && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            numStreamRows += 1
+            val key = joinKeys(currentStreamedRow)
+            if (!key.anyNull) {
+              currentHashMatches = hashedRelation.get(key)
+              if (currentHashMatches != null) {
+                currentMatchPosition = 0
+              }
+            }
+          }
+          if (currentHashMatches == null) {
+            return false
+          }
 
-      /**
-       * Searches the streamed iterator for the next row that has at least one match in hashtable.
-       *
-       * @return true if the search is successful, and false if the streamed iterator runs out of
-       *         tuples.
-       */
-      private final def fetchNext(): Boolean = {
-        currentHashMatches = null
-        currentMatchPosition = -1
-
-        while (currentHashMatches == null && streamIter.hasNext) {
-          currentStreamedRow = streamIter.next()
-          numStreamRows += 1
-          val key = joinKeys(currentStreamedRow)
-          if (!key.anyNull) {
-            currentHashMatches = hashedRelation.get(key)
+          // found some matches
+          buildSide match {
+            case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+            case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
+          }
+          if (boundCondition(joinRow)) {
+            return true
+          } else {
+            currentMatchPosition += 1
           }
         }
+        false  // unreachable
+      }
 
-        if (currentHashMatches == null) {
-          false
+      override final def next(): InternalRow = {
+        // next() could be called without calling hasNext()
+        if (hasNext) {
+          currentMatchPosition += 1
+          numOutputRows += 1
+          resultProjection(joinRow)
         } else {
-          currentMatchPosition = 0
-          true
+          throw new NoSuchElementException
         }
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/323d51f1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index 6d464d6..9e61430 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -78,8 +78,11 @@ trait HashOuterJoin {
 
   @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
   @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
-  @transient private[this] lazy val boundCondition =
+  @transient private[this] lazy val boundCondition = if (condition.isDefined) {
     newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+  } else {
+    (row: InternalRow) => true
+  }
 
   // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
   // iterator for performance purpose.

http://git-wip-us.apache.org/repos/asf/spark/blob/323d51f1/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 812f881..322a954 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
 case class SortMergeJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
 
@@ -64,6 +65,13 @@ case class SortMergeJoin(
     val numOutputRows = longMetric("numOutputRows")
 
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+      val boundCondition: (InternalRow) => Boolean = {
+        condition.map { cond =>
+          newPredicate(cond, left.output ++ right.output)
+        }.getOrElse {
+          (r: InternalRow) => true
+        }
+      }
       new RowIterator {
         // The projection used to extract keys from input rows of the left child.
         private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
@@ -89,26 +97,34 @@ case class SortMergeJoin(
         private[this] val resultProjection: (InternalRow) => InternalRow =
           UnsafeProjection.create(schema)
 
+        if (smjScanner.findNextInnerJoinRows()) {
+          currentRightMatches = smjScanner.getBufferedMatches
+          currentLeftRow = smjScanner.getStreamedRow
+          currentMatchIdx = 0
+        }
+
         override def advanceNext(): Boolean = {
-          if (currentMatchIdx == -1 || currentMatchIdx == currentRightMatches.length) {
-            if (smjScanner.findNextInnerJoinRows()) {
-              currentRightMatches = smjScanner.getBufferedMatches
-              currentLeftRow = smjScanner.getStreamedRow
-              currentMatchIdx = 0
-            } else {
-              currentRightMatches = null
-              currentLeftRow = null
-              currentMatchIdx = -1
+          while (currentMatchIdx >= 0) {
+            if (currentMatchIdx == currentRightMatches.length) {
+              if (smjScanner.findNextInnerJoinRows()) {
+                currentRightMatches = smjScanner.getBufferedMatches
+                currentLeftRow = smjScanner.getStreamedRow
+                currentMatchIdx = 0
+              } else {
+                currentRightMatches = null
+                currentLeftRow = null
+                currentMatchIdx = -1
+                return false
+              }
             }
-          }
-          if (currentLeftRow != null) {
             joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
             currentMatchIdx += 1
-            numOutputRows += 1
-            true
-          } else {
-            false
+            if (boundCondition(joinRow)) {
+              numOutputRows += 1
+              return true
+            }
           }
+          false
         }
 
         override def getRow: InternalRow = resultProjection(joinRow)

http://git-wip-us.apache.org/repos/asf/spark/blob/323d51f1/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 42fadaa..ab81b70 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.execution.joins
 
-import org.apache.spark.sql.{execution, DataFrame, Row, SQLConf}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -25,6 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
 
 class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
   import testImplicits.localSeqToDataFrameHolder
@@ -88,9 +88,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan,
         side: BuildSide) = {
-      val broadcastHashJoin =
-        execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
-      boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
+      joins.BroadcastHashJoin(leftKeys, rightKeys, side, boundCondition, leftPlan, rightPlan)
     }
 
     def makeSortMergeJoin(
@@ -100,9 +98,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan) = {
       val sortMergeJoin =
-        execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
-      val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
-      EnsureRequirements(sqlContext).apply(filteredJoin)
+        joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan)
+      EnsureRequirements(sqlContext).apply(sortMergeJoin)
     }
 
     test(s"$testName using BroadcastHashJoin (build=left)") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org