You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/04/26 21:43:51 UTC

spark git commit: [SPARK-14853] [SQL] Support LeftSemi/LeftAnti in SortMergeJoinExec

Repository: spark
Updated Branches:
  refs/heads/master 89f082de0 -> 7131b03bc


[SPARK-14853] [SQL] Support LeftSemi/LeftAnti in SortMergeJoinExec

## What changes were proposed in this pull request?

This PR updates SortMergeJoinExec to support LeftSemi/LeftAnti, so that it supports all join types, like the other three join implementations: BroadcastHashJoinExec, ShuffledHashJoinExec, and BroadcastNestedLoopJoinExec.

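This PR also simplifies the join selection in SparkStrategies by merging the separate ExistenceJoin, EquiJoinSelection, BroadcastNestedLoop, CartesianProduct, and DefaultJoin strategies into a single JoinSelection strategy.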

## How was this patch tested?

Added new tests.

Author: Davies Liu <da...@databricks.com>

Closes #12668 from davies/smj_semi.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7131b03b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7131b03b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7131b03b

Branch: refs/heads/master
Commit: 7131b03bcf00cdda99e350f697946d4020a0822f
Parents: 89f082d
Author: Davies Liu <da...@databricks.com>
Authored: Tue Apr 26 12:43:47 2016 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Tue Apr 26 12:43:47 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/SparkPlanner.scala      |   8 +-
 .../spark/sql/execution/SparkStrategies.scala   | 192 +++++++------------
 .../execution/joins/CartesianProductExec.scala  |  20 +-
 .../sql/execution/joins/SortMergeJoinExec.scala |  91 +++++++--
 .../scala/org/apache/spark/sql/JoinSuite.scala  |  14 +-
 .../execution/joins/ExistenceJoinSuite.scala    |  12 ++
 .../sql/execution/joins/InnerJoinSuite.scala    |   2 +-
 .../sql/execution/metric/SQLMetricsSuite.scala  |  18 +-
 .../spark/sql/hive/HiveSessionState.scala       |   8 +-
 .../apache/spark/sql/hive/StatisticsSuite.scala |   4 +-
 10 files changed, 194 insertions(+), 175 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7131b03b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index 0afa4c7..de832ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -38,13 +38,9 @@ class SparkPlanner(
       DDLStrategy ::
       SpecialLimits ::
       Aggregation ::
-      ExistenceJoin ::
-      EquiJoinSelection ::
+      JoinSelection ::
       InMemoryScans ::
-      BasicOperators ::
-      BroadcastNestedLoop ::
-      CartesianProduct ::
-      DefaultJoin :: Nil)
+      BasicOperators :: Nil)
 
   /**
    * Used to build table scan operators where complex projection and filtering are done using

http://git-wip-us.apache.org/repos/asf/spark/blob/7131b03b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3c10504..3955c5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -64,39 +64,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
-  object ExistenceJoin extends Strategy with PredicateHelper {
-    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case ExtractEquiJoinKeys(
-             LeftExistence(jt), leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        Seq(joins.BroadcastHashJoinExec(
-          leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right)))
-      // Find left semi joins where at least some predicates can be evaluated by matching join keys
-      case ExtractEquiJoinKeys(
-             LeftExistence(jt), leftKeys, rightKeys, condition, left, right) =>
-        Seq(joins.ShuffledHashJoinExec(
-          leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right)))
-      case _ => Nil
-    }
-  }
-
-  /**
-   * Matches a plan whose output should be small enough to be used in broadcast join.
-   */
-  object CanBroadcast {
-    def unapply(plan: LogicalPlan): Option[LogicalPlan] = {
-      if (plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold) {
-        Some(plan)
-      } else {
-        None
-      }
-    }
-  }
-
   /**
-   * Uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the predicates
-   * can be evaluated by matching join keys.
+   * Select the proper physical plan for join based on joining keys and size of logical plan.
    *
-   * Join implementations are chosen with the following precedence:
+   * At first, uses the [[ExtractEquiJoinKeys]] pattern to find joins where at least some of the
+   * predicates can be evaluated by matching join keys. If found, join implementations are chosen
+   * with the following precedence:
    *
    * - Broadcast: if one side of the join has an estimated physical size that is smaller than the
    *     user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold
@@ -107,8 +80,20 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    * - Shuffle hash join: if the average size of a single partition is small enough to build a hash
    *     table.
    * - Sort merge: if the matching join keys are sortable.
+   *
+   * If there are no joining keys, join implementations are chosen with the following precedence:
+   * - BroadcastNestedLoopJoin: if one side of the join could be broadcast
+   * - CartesianProduct: for Inner join
+   * - BroadcastNestedLoopJoin: in all other cases, building the smaller side
    */
-  object EquiJoinSelection extends Strategy with PredicateHelper {
+  object JoinSelection extends Strategy with PredicateHelper {
+
+    /**
+     * Matches a plan whose output should be small enough to be used in broadcast join.
+     */
+    private def canBroadcast(plan: LogicalPlan): Boolean = {
+      plan.statistics.sizeInBytes <= conf.autoBroadcastJoinThreshold
+    }
 
     /**
      * Matches a plan whose single partition should be small enough to build a hash table.
@@ -116,7 +101,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
      * Note: this assume that the number of partition is fixed, requires additional work if it's
      * dynamic.
      */
-    def canBuildHashMap(plan: LogicalPlan): Boolean = {
+    private def canBuildLocalHashMap(plan: LogicalPlan): Boolean = {
       plan.statistics.sizeInBytes < conf.autoBroadcastJoinThreshold * conf.numShufflePartitions
     }
 
@@ -131,76 +116,80 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       a.statistics.sizeInBytes * 3 <= b.statistics.sizeInBytes
     }
 
-    /**
-     * Returns whether we should use shuffle hash join or not.
-     *
-     * We should only use shuffle hash join when:
-     *  1) any single partition of a small table could fit in memory.
-     *  2) the smaller table is much smaller (3X) than the other one.
-     */
-    private def shouldShuffleHashJoin(left: LogicalPlan, right: LogicalPlan): Boolean = {
-      canBuildHashMap(left) && muchSmaller(left, right) ||
-        canBuildHashMap(right) && muchSmaller(right, left)
+    private def canBuildRight(joinType: JoinType): Boolean = joinType match {
+      case Inner | LeftOuter | LeftSemi | LeftAnti => true
+      case _ => false
     }
 
-    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+    private def canBuildLeft(joinType: JoinType): Boolean = joinType match {
+      case Inner | RightOuter => true
+      case _ => false
+    }
 
-      // --- Inner joins --------------------------------------------------------------------------
+    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
 
-      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        Seq(joins.BroadcastHashJoinExec(
-          leftKeys, rightKeys, Inner, BuildRight, condition, planLater(left), planLater(right)))
+      // --- BroadcastHashJoin --------------------------------------------------------------------
 
-      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+        if canBuildRight(joinType) && canBroadcast(right) =>
         Seq(joins.BroadcastHashJoinExec(
-          leftKeys, rightKeys, Inner, BuildLeft, condition, planLater(left), planLater(right)))
-
-      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
-        if !conf.preferSortMergeJoin && shouldShuffleHashJoin(left, right) ||
-          !RowOrdering.isOrderable(leftKeys) =>
-        val buildSide =
-          if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
-            BuildRight
-          } else {
-            BuildLeft
-          }
-        Seq(joins.ShuffledHashJoinExec(
-          leftKeys, rightKeys, Inner, buildSide, condition, planLater(left), planLater(right)))
-
-      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
-        if RowOrdering.isOrderable(leftKeys) =>
-        joins.SortMergeJoinExec(
-          leftKeys, rightKeys, Inner, condition, planLater(left), planLater(right)) :: Nil
-
-      // --- Outer joins --------------------------------------------------------------------------
+          leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
 
-      case ExtractEquiJoinKeys(
-          LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+        if canBuildLeft(joinType) && canBroadcast(left) =>
         Seq(joins.BroadcastHashJoinExec(
-          leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
+          leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
 
-      case ExtractEquiJoinKeys(
-          RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
-        Seq(joins.BroadcastHashJoinExec(
-          leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
+      // --- ShuffledHashJoin ---------------------------------------------------------------------
 
-      case ExtractEquiJoinKeys(LeftOuter, leftKeys, rightKeys, condition, left, right)
-         if !conf.preferSortMergeJoin && canBuildHashMap(right) && muchSmaller(right, left) ||
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+         if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right)
+           && muchSmaller(right, left) ||
            !RowOrdering.isOrderable(leftKeys) =>
         Seq(joins.ShuffledHashJoinExec(
-          leftKeys, rightKeys, LeftOuter, BuildRight, condition, planLater(left), planLater(right)))
+          leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
 
-      case ExtractEquiJoinKeys(RightOuter, leftKeys, rightKeys, condition, left, right)
-         if !conf.preferSortMergeJoin && canBuildHashMap(left) && muchSmaller(left, right) ||
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+         if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left)
+           && muchSmaller(left, right) ||
            !RowOrdering.isOrderable(leftKeys) =>
         Seq(joins.ShuffledHashJoinExec(
-          leftKeys, rightKeys, RightOuter, BuildLeft, condition, planLater(left), planLater(right)))
+          leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right)))
+
+      // --- SortMergeJoin ------------------------------------------------------------
 
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
         joins.SortMergeJoinExec(
           leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
 
+      // --- Without joining keys ------------------------------------------------------------
+
+      // Pick BroadcastNestedLoopJoin if one side could be broadcasted
+      case j @ logical.Join(left, right, joinType, condition)
+          if canBuildRight(joinType) && canBroadcast(right) =>
+        joins.BroadcastNestedLoopJoinExec(
+          planLater(left), planLater(right), BuildRight, joinType, condition) :: Nil
+      case j @ logical.Join(left, right, joinType, condition)
+          if canBuildLeft(joinType) && canBroadcast(left) =>
+        joins.BroadcastNestedLoopJoinExec(
+          planLater(left), planLater(right), BuildLeft, joinType, condition) :: Nil
+
+      // Pick CartesianProduct for InnerJoin
+      case logical.Join(left, right, Inner, condition) =>
+        joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
+
+      case logical.Join(left, right, joinType, condition) =>
+        val buildSide =
+          if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
+            BuildRight
+          } else {
+            BuildLeft
+          }
+        // This join could be very slow or OOM
+        joins.BroadcastNestedLoopJoinExec(
+          planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
+
       // --- Cases where this strategy does not apply ---------------------------------------------
 
       case _ => Nil
@@ -277,45 +266,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     }
   }
 
-  object BroadcastNestedLoop extends Strategy {
-    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case j @ logical.Join(CanBroadcast(left), right, Inner | RightOuter, condition) =>
-        execution.joins.BroadcastNestedLoopJoinExec(
-          planLater(left), planLater(right), joins.BuildLeft, j.joinType, condition) :: Nil
-      case j @ logical.Join(left, CanBroadcast(right), Inner | LeftOuter | LeftSemi, condition) =>
-        execution.joins.BroadcastNestedLoopJoinExec(
-          planLater(left), planLater(right), joins.BuildRight, j.joinType, condition) :: Nil
-      case _ => Nil
-    }
-  }
-
-  object CartesianProduct extends Strategy {
-    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case logical.Join(left, right, Inner, None) =>
-        execution.joins.CartesianProductExec(planLater(left), planLater(right)) :: Nil
-      case logical.Join(left, right, Inner, Some(condition)) =>
-        execution.FilterExec(condition,
-          execution.joins.CartesianProductExec(planLater(left), planLater(right))) :: Nil
-      case _ => Nil
-    }
-  }
-
-  object DefaultJoin extends Strategy {
-    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case logical.Join(left, right, joinType, condition) =>
-        val buildSide =
-          if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
-            joins.BuildRight
-          } else {
-            joins.BuildLeft
-          }
-        // This join could be very slow or even hang forever
-        joins.BroadcastNestedLoopJoinExec(
-          planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
-      case _ => Nil
-    }
-  }
-
   protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
 
   object InMemoryScans extends Strategy {

http://git-wip-us.apache.org/repos/asf/spark/blob/7131b03b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 3ce7c0e..67f5919 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark._
 import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, JoinedRow, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
 import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -79,7 +79,10 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
 }
 
 
-case class CartesianProductExec(left: SparkPlan, right: SparkPlan) extends BinaryExecNode {
+case class CartesianProductExec(
+    left: SparkPlan,
+    right: SparkPlan,
+    condition: Option[Expression]) extends BinaryExecNode {
   override def output: Seq[Attribute] = left.output ++ right.output
 
   override private[sql] lazy val metrics = Map(
@@ -94,7 +97,18 @@ case class CartesianProductExec(left: SparkPlan, right: SparkPlan) extends Binar
     val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
     pair.mapPartitionsInternal { iter =>
       val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
-      iter.map { r =>
+      val filtered = if (condition.isDefined) {
+        val boundCondition: (InternalRow) => Boolean =
+          newPredicate(condition.get, left.output ++ right.output)
+        val joined = new JoinedRow
+
+        iter.filter { r =>
+          boundCondition(joined(r._1, r._2))
+        }
+      } else {
+        iter
+      }
+      filtered.map { r =>
         numOutputRows += 1
         joiner.join(r._1, r._2)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/7131b03b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 96b283a..a4c5491 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -53,6 +53,8 @@ case class SortMergeJoinExec(
         left.output.map(_.withNullability(true)) ++ right.output
       case FullOuter =>
         (left.output ++ right.output).map(_.withNullability(true))
+      case LeftExistence(_) =>
+        left.output
       case x =>
         throw new IllegalArgumentException(
           s"${getClass.getSimpleName} should not take $x as the JoinType")
@@ -65,6 +67,7 @@ case class SortMergeJoinExec(
     case LeftOuter => left.outputPartitioning
     case RightOuter => right.outputPartitioning
     case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case LeftExistence(_) => left.outputPartitioning
     case x =>
       throw new IllegalArgumentException(
         s"${getClass.getSimpleName} should not take $x as the JoinType")
@@ -100,6 +103,7 @@ case class SortMergeJoinExec(
           (r: InternalRow) => true
         }
       }
+
       // An ordering that can be used to compare keys from both sides.
       val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
       val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
@@ -107,27 +111,17 @@ case class SortMergeJoinExec(
       joinType match {
         case Inner =>
           new RowIterator {
-            // The projection used to extract keys from input rows of the left child.
-            private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
-
-            // The projection used to extract keys from input rows of the right child.
-            private[this] val rightKeyGenerator = UnsafeProjection.create(rightKeys, right.output)
-
-            // An ordering that can be used to compare keys from both sides.
-            private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
             private[this] var currentLeftRow: InternalRow = _
             private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
             private[this] var currentMatchIdx: Int = -1
             private[this] val smjScanner = new SortMergeJoinScanner(
-              leftKeyGenerator,
-              rightKeyGenerator,
+              createLeftKeyGenerator(),
+              createRightKeyGenerator(),
               keyOrdering,
               RowIterator.fromScala(leftIter),
               RowIterator.fromScala(rightIter)
             )
             private[this] val joinRow = new JoinedRow
-            private[this] val resultProjection: (InternalRow) => InternalRow =
-              UnsafeProjection.create(schema)
 
             if (smjScanner.findNextInnerJoinRows()) {
               currentRightMatches = smjScanner.getBufferedMatches
@@ -159,7 +153,7 @@ case class SortMergeJoinExec(
               false
             }
 
-            override def getRow: InternalRow = resultProjection(joinRow)
+            override def getRow: InternalRow = resultProj(joinRow)
           }.toScala
 
         case LeftOuter =>
@@ -204,6 +198,77 @@ case class SortMergeJoinExec(
             resultProj,
             numOutputRows).toScala
 
+        case LeftSemi =>
+          new RowIterator {
+            private[this] var currentLeftRow: InternalRow = _
+            private[this] val smjScanner = new SortMergeJoinScanner(
+              createLeftKeyGenerator(),
+              createRightKeyGenerator(),
+              keyOrdering,
+              RowIterator.fromScala(leftIter),
+              RowIterator.fromScala(rightIter)
+            )
+            private[this] val joinRow = new JoinedRow
+
+            override def advanceNext(): Boolean = {
+              while (smjScanner.findNextInnerJoinRows()) {
+                val currentRightMatches = smjScanner.getBufferedMatches
+                currentLeftRow = smjScanner.getStreamedRow
+                var i = 0
+                while (i < currentRightMatches.length) {
+                  joinRow(currentLeftRow, currentRightMatches(i))
+                  if (boundCondition(joinRow)) {
+                    numOutputRows += 1
+                    return true
+                  }
+                  i += 1
+                }
+              }
+              false
+            }
+
+            override def getRow: InternalRow = currentLeftRow
+          }.toScala
+
+        case LeftAnti =>
+          new RowIterator {
+            private[this] var currentLeftRow: InternalRow = _
+            private[this] val smjScanner = new SortMergeJoinScanner(
+              createLeftKeyGenerator(),
+              createRightKeyGenerator(),
+              keyOrdering,
+              RowIterator.fromScala(leftIter),
+              RowIterator.fromScala(rightIter)
+            )
+            private[this] val joinRow = new JoinedRow
+
+            override def advanceNext(): Boolean = {
+              while (smjScanner.findNextOuterJoinRows()) {
+                currentLeftRow = smjScanner.getStreamedRow
+                val currentRightMatches = smjScanner.getBufferedMatches
+                if (currentRightMatches == null) {
+                  return true
+                }
+                var i = 0
+                var found = false
+                while (!found && i < currentRightMatches.length) {
+                  joinRow(currentLeftRow, currentRightMatches(i))
+                  if (boundCondition(joinRow)) {
+                    found = true
+                  }
+                  i += 1
+                }
+                if (!found) {
+                  numOutputRows += 1
+                  return true
+                }
+              }
+              false
+            }
+
+            override def getRow: InternalRow = currentLeftRow
+          }.toScala
+
         case x =>
           throw new IllegalArgumentException(
             s"SortMergeJoin should not take $x as the JoinType")

http://git-wip-us.apache.org/repos/asf/spark/blob/7131b03b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index ef9bb7e..8cbad04 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -37,7 +37,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     val x = testData2.as("x")
     val y = testData2.as("y")
     val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan
-    val planned = sqlContext.sessionState.planner.EquiJoinSelection(join)
+    val planned = sqlContext.sessionState.planner.JoinSelection(join)
     assert(planned.size === 1)
   }
 
@@ -65,7 +65,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "0") {
       Seq(
         ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
-          classOf[ShuffledHashJoinExec]),
+          classOf[SortMergeJoinExec]),
         ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[BroadcastNestedLoopJoinExec]),
         ("SELECT * FROM testData JOIN testData2", classOf[CartesianProductExec]),
         ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProductExec]),
@@ -99,7 +99,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
           classOf[BroadcastNestedLoopJoinExec]),
         ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)",
           classOf[BroadcastNestedLoopJoinExec]),
-        ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoinExec]),
+        ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[SortMergeJoinExec]),
         ("SELECT * FROM testData LEFT ANTI JOIN testData2", classOf[BroadcastNestedLoopJoinExec])
       ).foreach(assertJoin)
     }
@@ -144,7 +144,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     val x = testData2.as("x")
     val y = testData2.as("y")
     val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan
-    val planned = sqlContext.sessionState.planner.EquiJoinSelection(join)
+    val planned = sqlContext.sessionState.planner.JoinSelection(join)
     assert(planned.size === 1)
   }
 
@@ -449,9 +449,9 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
       Seq(
         ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
-          classOf[ShuffledHashJoinExec]),
+          classOf[SortMergeJoinExec]),
         ("SELECT * FROM testData LEFT ANTI JOIN testData2 ON key = a",
-          classOf[ShuffledHashJoinExec])
+          classOf[SortMergeJoinExec])
       ).foreach(assertJoin)
     }
 
@@ -475,7 +475,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
 
       Seq(
         ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
-          classOf[ShuffledHashJoinExec]),
+          classOf[SortMergeJoinExec]),
         ("SELECT * FROM testData LEFT SEMI JOIN testData2",
           classOf[BroadcastNestedLoopJoinExec]),
         ("SELECT * FROM testData JOIN testData2",

http://git-wip-us.apache.org/repos/asf/spark/blob/7131b03b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
index bc838ee..c7c10ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala
@@ -104,6 +104,18 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext {
       }
     }
 
+    test(s"$testName using SortMergeJoin") {
+      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+            EnsureRequirements(left.sqlContext.sessionState.conf).apply(
+              SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)),
+            expectedAnswer,
+            sortAnswers = true)
+        }
+      }
+    }
+
     test(s"$testName using BroadcastNestedLoopJoin build left") {
       withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/7131b03b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 933f32e..2a4a369 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -189,7 +189,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
     test(s"$testName using CartesianProduct") {
       withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
         checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-          FilterExec(condition(), CartesianProductExec(left, right)),
+          CartesianProductExec(left, right, Some(condition())),
           expectedAnswer.map(Row.fromTuple),
           sortAnswers = true)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/7131b03b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 695b182..1859c6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -255,28 +255,14 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
     val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value")
     // Assume the execution plan is
-    // ... -> BroadcastLeftSemiJoinHash(nodeId = 0)
+    // ... -> BroadcastHashJoin(nodeId = 0)
     val df = df1.join(broadcast(df2), $"key" === $"key2", "leftsemi")
     testSparkPlanMetrics(df, 2, Map(
-      0L -> ("BroadcastLeftSemiJoinHash", Map(
+      0L -> ("BroadcastHashJoin", Map(
         "number of output rows" -> 2L)))
     )
   }
 
-  test("ShuffledHashJoin metrics") {
-    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
-      val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
-      val df2 = Seq((1, "1"), (2, "2"), (3, "3"), (4, "4")).toDF("key2", "value")
-      // Assume the execution plan is
-      // ... -> ShuffledHashJoin(nodeId = 0)
-      val df = df1.join(df2, $"key" === $"key2", "leftsemi")
-      testSparkPlanMetrics(df, 1, Map(
-        0L -> ("ShuffledHashJoin", Map(
-          "number of output rows" -> 2L)))
-      )
-    }
-  }
-
   test("CartesianProduct metrics") {
     val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
     testDataForJoin.registerTempTable("testDataForJoin")

http://git-wip-us.apache.org/repos/asf/spark/blob/7131b03b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index 4a8978e..9633f9e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -120,12 +120,8 @@ private[hive] class HiveSessionState(sparkSession: SparkSession)
           DataSinks,
           Scripts,
           Aggregation,
-          ExistenceJoin,
-          EquiJoinSelection,
-          BasicOperators,
-          BroadcastNestedLoop,
-          CartesianProduct,
-          DefaultJoin
+          JoinSelection,
+          BasicOperators
         )
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/7131b03b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 93a6f0b..f6b5101 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -228,10 +228,10 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton {
       assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")
 
       val shj = df.queryExecution.sparkPlan.collect {
-        case j: ShuffledHashJoinExec => j
+        case j: SortMergeJoinExec => j
       }
       assert(shj.size === 1,
-        "LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off")
+        "SortMergeJoinExec should be planned when BroadcastHashJoin is turned off")
 
       sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=$tmp")
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org