You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/03/16 03:58:55 UTC

spark git commit: [SPARK-13918][SQL] Merge SortMergeJoin and SortMergeOuterJoin

Repository: spark
Updated Branches:
  refs/heads/master 643649dcb -> bbd887f53


[SPARK-13918][SQL] Merge SortMergeJoin and SortMergeOuterJoin

## What changes were proposed in this pull request?

This PR just moves some code from SortMergeOuterJoin into SortMergeJoin.

This is to support codegen for outer joins.

## How was this patch tested?

existing tests.

Author: Davies Liu <da...@databricks.com>

Closes #11743 from davies/gen_smjouter.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bbd887f5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bbd887f5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bbd887f5

Branch: refs/heads/master
Commit: bbd887f53cc4fa03d97932e1b570bd7180783da5
Parents: 643649d
Author: Davies Liu <da...@databricks.com>
Authored: Tue Mar 15 19:58:49 2016 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Mar 15 19:58:49 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   |   4 +-
 .../spark/sql/execution/WholeStageCodegen.scala |   3 +-
 .../sql/execution/joins/SortMergeJoin.scala     | 505 +++++++++++++++++--
 .../execution/joins/SortMergeOuterJoin.scala    | 464 -----------------
 .../scala/org/apache/spark/sql/JoinSuite.scala  |   7 +-
 .../spark/sql/execution/PlannerSuite.scala      |   3 +
 .../sql/execution/joins/InnerJoinSuite.scala    |   2 +-
 .../sql/execution/joins/OuterJoinSuite.scala    |   4 +-
 .../sql/execution/metric/SQLMetricsSuite.scala  |  10 +-
 9 files changed, 467 insertions(+), 535 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bbd887f5/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 113cf9a..7fc6a82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -120,7 +120,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
         joins.SortMergeJoin(
-          leftKeys, rightKeys, condition, planLater(left), planLater(right)) :: Nil
+          leftKeys, rightKeys, Inner, condition, planLater(left), planLater(right)) :: Nil
 
       // --- Outer joins --------------------------------------------------------------------------
 
@@ -136,7 +136,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
         if RowOrdering.isOrderable(leftKeys) =>
-        joins.SortMergeOuterJoin(
+        joins.SortMergeJoin(
           leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
 
       // --- Cases where this strategy does not apply ---------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/bbd887f5/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 81676d3..a54b772 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -22,6 +22,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.toCommentSafeString
@@ -450,7 +451,7 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
    * Inserts a InputAdapter on top of those that do not support codegen.
    */
   private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match {
-    case j @ SortMergeJoin(_, _, _, left, right) =>
+    case j @ SortMergeJoin(_, _, _, _, left, right) if j.supportCodegen =>
       // The children of SortMergeJoin should do codegen separately.
       j.copy(left = InputAdapter(insertWholeStageCodegen(left)),
         right = InputAdapter(insertWholeStageCodegen(right)))

http://git-wip-us.apache.org/repos/asf/spark/blob/bbd887f5/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index cffd6f6..d0724ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -23,9 +23,11 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, RowIterator, SparkPlan}
-import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
+import org.apache.spark.util.collection.BitSet
 
 /**
  * Performs an sort merge join of two child relations.
@@ -33,6 +35,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
 case class SortMergeJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    joinType: JoinType,
     condition: Option[Expression],
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode with CodegenSupport {
@@ -40,10 +43,32 @@ case class SortMergeJoin(
   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
 
-  override def output: Seq[Attribute] = left.output ++ right.output
+  override def output: Seq[Attribute] = {
+    joinType match {
+      case Inner =>
+        left.output ++ right.output
+      case LeftOuter =>
+        left.output ++ right.output.map(_.withNullability(true))
+      case RightOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output
+      case FullOuter =>
+        (left.output ++ right.output).map(_.withNullability(true))
+      case x =>
+        throw new IllegalArgumentException(
+          s"${getClass.getSimpleName} should not take $x as the JoinType")
+    }
+  }
 
-  override def outputPartitioning: Partitioning =
-    PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+  override def outputPartitioning: Partitioning = joinType match {
+    case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning))
+    // For left and right outer joins, the output is partitioned by the streamed input's join keys.
+    case LeftOuter => left.outputPartitioning
+    case RightOuter => right.outputPartitioning
+    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+    case x =>
+      throw new IllegalArgumentException(
+        s"${getClass.getSimpleName} should not take $x as the JoinType")
+  }
 
   override def requiredChildDistribution: Seq[Distribution] =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
@@ -58,6 +83,12 @@ case class SortMergeJoin(
     keys.map(SortOrder(_, Ascending))
   }
 
+  private def createLeftKeyGenerator(): Projection =
+    UnsafeProjection.create(leftKeys, left.output)
+
+  private def createRightKeyGenerator(): Projection =
+    UnsafeProjection.create(rightKeys, right.output)
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
 
@@ -69,64 +100,122 @@ case class SortMergeJoin(
           (r: InternalRow) => true
         }
       }
-      new RowIterator {
-        // The projection used to extract keys from input rows of the left child.
-        private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
-
-        // The projection used to extract keys from input rows of the right child.
-        private[this] val rightKeyGenerator = UnsafeProjection.create(rightKeys, right.output)
-
-        // An ordering that can be used to compare keys from both sides.
-        private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
-        private[this] var currentLeftRow: InternalRow = _
-        private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
-        private[this] var currentMatchIdx: Int = -1
-        private[this] val smjScanner = new SortMergeJoinScanner(
-          leftKeyGenerator,
-          rightKeyGenerator,
-          keyOrdering,
-          RowIterator.fromScala(leftIter),
-          RowIterator.fromScala(rightIter)
-        )
-        private[this] val joinRow = new JoinedRow
-        private[this] val resultProjection: (InternalRow) => InternalRow =
-          UnsafeProjection.create(schema)
-
-        if (smjScanner.findNextInnerJoinRows()) {
-          currentRightMatches = smjScanner.getBufferedMatches
-          currentLeftRow = smjScanner.getStreamedRow
-          currentMatchIdx = 0
-        }
+      // An ordering that can be used to compare keys from both sides.
+      val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
+      val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
+
+      joinType match {
+        case Inner =>
+          new RowIterator {
+            // The projection used to extract keys from input rows of the left child.
+            private[this] val leftKeyGenerator = UnsafeProjection.create(leftKeys, left.output)
+
+            // The projection used to extract keys from input rows of the right child.
+            private[this] val rightKeyGenerator = UnsafeProjection.create(rightKeys, right.output)
+
+            // An ordering that can be used to compare keys from both sides.
+            private[this] val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
+            private[this] var currentLeftRow: InternalRow = _
+            private[this] var currentRightMatches: ArrayBuffer[InternalRow] = _
+            private[this] var currentMatchIdx: Int = -1
+            private[this] val smjScanner = new SortMergeJoinScanner(
+              leftKeyGenerator,
+              rightKeyGenerator,
+              keyOrdering,
+              RowIterator.fromScala(leftIter),
+              RowIterator.fromScala(rightIter)
+            )
+            private[this] val joinRow = new JoinedRow
+            private[this] val resultProjection: (InternalRow) => InternalRow =
+              UnsafeProjection.create(schema)
+
+            if (smjScanner.findNextInnerJoinRows()) {
+              currentRightMatches = smjScanner.getBufferedMatches
+              currentLeftRow = smjScanner.getStreamedRow
+              currentMatchIdx = 0
+            }
 
-        override def advanceNext(): Boolean = {
-          while (currentMatchIdx >= 0) {
-            if (currentMatchIdx == currentRightMatches.length) {
-              if (smjScanner.findNextInnerJoinRows()) {
-                currentRightMatches = smjScanner.getBufferedMatches
-                currentLeftRow = smjScanner.getStreamedRow
-                currentMatchIdx = 0
-              } else {
-                currentRightMatches = null
-                currentLeftRow = null
-                currentMatchIdx = -1
-                return false
+            override def advanceNext(): Boolean = {
+              while (currentMatchIdx >= 0) {
+                if (currentMatchIdx == currentRightMatches.length) {
+                  if (smjScanner.findNextInnerJoinRows()) {
+                    currentRightMatches = smjScanner.getBufferedMatches
+                    currentLeftRow = smjScanner.getStreamedRow
+                    currentMatchIdx = 0
+                  } else {
+                    currentRightMatches = null
+                    currentLeftRow = null
+                    currentMatchIdx = -1
+                    return false
+                  }
+                }
+                joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
+                currentMatchIdx += 1
+                if (boundCondition(joinRow)) {
+                  numOutputRows += 1
+                  return true
+                }
               }
+              false
             }
-            joinRow(currentLeftRow, currentRightMatches(currentMatchIdx))
-            currentMatchIdx += 1
-            if (boundCondition(joinRow)) {
-              numOutputRows += 1
-              return true
-            }
-          }
-          false
-        }
 
-        override def getRow: InternalRow = resultProjection(joinRow)
-      }.toScala
+            override def getRow: InternalRow = resultProjection(joinRow)
+          }.toScala
+
+        case LeftOuter =>
+          val smjScanner = new SortMergeJoinScanner(
+            streamedKeyGenerator = createLeftKeyGenerator(),
+            bufferedKeyGenerator = createRightKeyGenerator(),
+            keyOrdering,
+            streamedIter = RowIterator.fromScala(leftIter),
+            bufferedIter = RowIterator.fromScala(rightIter)
+          )
+          val rightNullRow = new GenericInternalRow(right.output.length)
+          new LeftOuterIterator(
+            smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala
+
+        case RightOuter =>
+          val smjScanner = new SortMergeJoinScanner(
+            streamedKeyGenerator = createRightKeyGenerator(),
+            bufferedKeyGenerator = createLeftKeyGenerator(),
+            keyOrdering,
+            streamedIter = RowIterator.fromScala(rightIter),
+            bufferedIter = RowIterator.fromScala(leftIter)
+          )
+          val leftNullRow = new GenericInternalRow(left.output.length)
+          new RightOuterIterator(
+            smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala
+
+        case FullOuter =>
+          val leftNullRow = new GenericInternalRow(left.output.length)
+          val rightNullRow = new GenericInternalRow(right.output.length)
+          val smjScanner = new SortMergeFullOuterJoinScanner(
+            leftKeyGenerator = createLeftKeyGenerator(),
+            rightKeyGenerator = createRightKeyGenerator(),
+            keyOrdering,
+            leftIter = RowIterator.fromScala(leftIter),
+            rightIter = RowIterator.fromScala(rightIter),
+            boundCondition,
+            leftNullRow,
+            rightNullRow)
+
+          new FullOuterIterator(
+            smjScanner,
+            resultProj,
+            numOutputRows).toScala
+
+        case x =>
+          throw new IllegalArgumentException(
+            s"SortMergeJoin should not take $x as the JoinType")
+      }
+
     }
   }
 
+  override def supportCodegen: Boolean = {
+    joinType == Inner
+  }
+
   override def upstreams(): Seq[RDD[InternalRow]] = {
     left.execute() :: right.execute() :: Nil
   }
@@ -376,7 +465,7 @@ case class SortMergeJoin(
 }
 
 /**
- * Helper class that is used to implement [[SortMergeJoin]] and [[SortMergeOuterJoin]].
+ * Helper class that is used to implement [[SortMergeJoin]].
  *
  * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]]
  * ([[findNextOuterJoinRows()]]), which returns `true` if a result has been produced and `false`
@@ -570,3 +659,307 @@ private[joins] class SortMergeJoinScanner(
     } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0)
   }
 }
+
+/**
+ * An iterator for outputting rows in left outer join.
+ */
+private class LeftOuterIterator(
+  smjScanner: SortMergeJoinScanner,
+  rightNullRow: InternalRow,
+  boundCondition: InternalRow => Boolean,
+  resultProj: InternalRow => InternalRow,
+  numOutputRows: LongSQLMetric)
+  extends OneSideOuterIterator(
+    smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
+
+  protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
+  protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
+}
+
+/**
+ * An iterator for outputting rows in right outer join.
+ */
+private class RightOuterIterator(
+  smjScanner: SortMergeJoinScanner,
+  leftNullRow: InternalRow,
+  boundCondition: InternalRow => Boolean,
+  resultProj: InternalRow => InternalRow,
+  numOutputRows: LongSQLMetric)
+  extends OneSideOuterIterator(
+    smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
+
+  protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
+  protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
+}
+
+/**
+ * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]].
+ *
+ * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the
+ * streamed side will output 0 or many rows, one for each matching row on the buffered side.
+ * If there are no matches, then the buffered side of the joined output will be a null row.
+ *
+ * In left outer join, the left is the streamed side and the right is the buffered side.
+ * In right outer join, the right is the streamed side and the left is the buffered side.
+ *
+ * @param smjScanner a scanner that streams rows and buffers any matching rows
+ * @param bufferedSideNullRow the default row to return when a streamed row has no matches
+ * @param boundCondition an additional filter condition for buffered rows
+ * @param resultProj how the output should be projected
+ * @param numOutputRows an accumulator metric for the number of rows output
+ */
+private abstract class OneSideOuterIterator(
+  smjScanner: SortMergeJoinScanner,
+  bufferedSideNullRow: InternalRow,
+  boundCondition: InternalRow => Boolean,
+  resultProj: InternalRow => InternalRow,
+  numOutputRows: LongSQLMetric) extends RowIterator {
+
+  // A row to store the joined result, reused many times
+  protected[this] val joinedRow: JoinedRow = new JoinedRow()
+
+  // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row
+  private[this] var bufferIndex: Int = 0
+
+  // This iterator is initialized lazily so there should be no matches initially
+  assert(smjScanner.getBufferedMatches.length == 0)
+
+  // Set output methods to be overridden by subclasses
+  protected def setStreamSideOutput(row: InternalRow): Unit
+  protected def setBufferedSideOutput(row: InternalRow): Unit
+
+  /**
+   * Advance to the next row on the stream side and populate the buffer with matches.
+   * @return whether there are more rows in the stream to consume.
+   */
+  private def advanceStream(): Boolean = {
+    bufferIndex = 0
+    if (smjScanner.findNextOuterJoinRows()) {
+      setStreamSideOutput(smjScanner.getStreamedRow)
+      if (smjScanner.getBufferedMatches.isEmpty) {
+        // There are no matching rows in the buffer, so return the null row
+        setBufferedSideOutput(bufferedSideNullRow)
+      } else {
+        // Find the next row in the buffer that satisfied the bound condition
+        if (!advanceBufferUntilBoundConditionSatisfied()) {
+          setBufferedSideOutput(bufferedSideNullRow)
+        }
+      }
+      true
+    } else {
+      // Stream has been exhausted
+      false
+    }
+  }
+
+  /**
+   * Advance to the next row in the buffer that satisfies the bound condition.
+   * @return whether there is such a row in the current buffer.
+   */
+  private def advanceBufferUntilBoundConditionSatisfied(): Boolean = {
+    var foundMatch: Boolean = false
+    while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) {
+      setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex))
+      foundMatch = boundCondition(joinedRow)
+      bufferIndex += 1
+    }
+    foundMatch
+  }
+
+  override def advanceNext(): Boolean = {
+    val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream()
+    if (r) numOutputRows += 1
+    r
+  }
+
+  override def getRow: InternalRow = resultProj(joinedRow)
+}
+
+private class SortMergeFullOuterJoinScanner(
+  leftKeyGenerator: Projection,
+  rightKeyGenerator: Projection,
+  keyOrdering: Ordering[InternalRow],
+  leftIter: RowIterator,
+  rightIter: RowIterator,
+  boundCondition: InternalRow => Boolean,
+  leftNullRow: InternalRow,
+  rightNullRow: InternalRow)  {
+  private[this] val joinedRow: JoinedRow = new JoinedRow()
+  private[this] var leftRow: InternalRow = _
+  private[this] var leftRowKey: InternalRow = _
+  private[this] var rightRow: InternalRow = _
+  private[this] var rightRowKey: InternalRow = _
+
+  private[this] var leftIndex: Int = 0
+  private[this] var rightIndex: Int = 0
+  private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+  private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
+  private[this] var leftMatched: BitSet = new BitSet(1)
+  private[this] var rightMatched: BitSet = new BitSet(1)
+
+  advancedLeft()
+  advancedRight()
+
+  // --- Private methods --------------------------------------------------------------------------
+
+  /**
+   * Advance the left iterator and compute the new row's join key.
+   * @return true if the left iterator returned a row and false otherwise.
+   */
+  private def advancedLeft(): Boolean = {
+    if (leftIter.advanceNext()) {
+      leftRow = leftIter.getRow
+      leftRowKey = leftKeyGenerator(leftRow)
+      true
+    } else {
+      leftRow = null
+      leftRowKey = null
+      false
+    }
+  }
+
+  /**
+   * Advance the right iterator and compute the new row's join key.
+   * @return true if the right iterator returned a row and false otherwise.
+   */
+  private def advancedRight(): Boolean = {
+    if (rightIter.advanceNext()) {
+      rightRow = rightIter.getRow
+      rightRowKey = rightKeyGenerator(rightRow)
+      true
+    } else {
+      rightRow = null
+      rightRowKey = null
+      false
+    }
+  }
+
+  /**
+   * Populate the left and right buffers with rows matching the provided key.
+   * This consumes rows from both iterators until their keys are different from the matching key.
+   */
+  private def findMatchingRows(matchingKey: InternalRow): Unit = {
+    leftMatches.clear()
+    rightMatches.clear()
+    leftIndex = 0
+    rightIndex = 0
+
+    while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) {
+      leftMatches += leftRow.copy()
+      advancedLeft()
+    }
+    while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) {
+      rightMatches += rightRow.copy()
+      advancedRight()
+    }
+
+    if (leftMatches.size <= leftMatched.capacity) {
+      leftMatched.clear()
+    } else {
+      leftMatched = new BitSet(leftMatches.size)
+    }
+    if (rightMatches.size <= rightMatched.capacity) {
+      rightMatched.clear()
+    } else {
+      rightMatched = new BitSet(rightMatches.size)
+    }
+  }
+
+  /**
+   * Scan the left and right buffers for the next valid match.
+   *
+   * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers.
+   * If a left row has no valid matches on the right, or a right row has no valid matches on the
+   * left, then the row is joined with the null row and the result is considered a valid match.
+   *
+   * @return true if a valid match is found, false otherwise.
+   */
+  private def scanNextInBuffered(): Boolean = {
+    while (leftIndex < leftMatches.size) {
+      while (rightIndex < rightMatches.size) {
+        joinedRow(leftMatches(leftIndex), rightMatches(rightIndex))
+        if (boundCondition(joinedRow)) {
+          leftMatched.set(leftIndex)
+          rightMatched.set(rightIndex)
+          rightIndex += 1
+          return true
+        }
+        rightIndex += 1
+      }
+      rightIndex = 0
+      if (!leftMatched.get(leftIndex)) {
+        // the left row has never matched any right row, join it with null row
+        joinedRow(leftMatches(leftIndex), rightNullRow)
+        leftIndex += 1
+        return true
+      }
+      leftIndex += 1
+    }
+
+    while (rightIndex < rightMatches.size) {
+      if (!rightMatched.get(rightIndex)) {
+        // the right row has never matched any left row, join it with null row
+        joinedRow(leftNullRow, rightMatches(rightIndex))
+        rightIndex += 1
+        return true
+      }
+      rightIndex += 1
+    }
+
+    // There are no more valid matches in the left and right buffers
+    false
+  }
+
+  // --- Public methods --------------------------------------------------------------------------
+
+  def getJoinedRow(): JoinedRow = joinedRow
+
+  def advanceNext(): Boolean = {
+    // If we already buffered some matching rows, use them directly
+    if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) {
+      if (scanNextInBuffered()) {
+        return true
+      }
+    }
+
+    if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) {
+      joinedRow(leftRow.copy(), rightNullRow)
+      advancedLeft()
+      true
+    } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) {
+      joinedRow(leftNullRow, rightRow.copy())
+      advancedRight()
+      true
+    } else if (leftRow != null && rightRow != null) {
+      // Both rows are present and neither have null values,
+      // so we populate the buffers with rows matching the next key
+      val comp = keyOrdering.compare(leftRowKey, rightRowKey)
+      if (comp <= 0) {
+        findMatchingRows(leftRowKey.copy())
+      } else {
+        findMatchingRows(rightRowKey.copy())
+      }
+      scanNextInBuffered()
+      true
+    } else {
+      // Both iterators have been consumed
+      false
+    }
+  }
+}
+
+private class FullOuterIterator(
+  smjScanner: SortMergeFullOuterJoinScanner,
+  resultProj: InternalRow => InternalRow,
+  numRows: LongSQLMetric
+) extends RowIterator {
+  private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
+
+  override def advanceNext(): Boolean = {
+    val r = smjScanner.advanceNext()
+    if (r) numRows += 1
+    r
+  }
+
+  override def getRow: InternalRow = resultProj(joinedRow)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/bbd887f5/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
deleted file mode 100644
index 40a6c93..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala
+++ /dev/null
@@ -1,464 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan}
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
-import org.apache.spark.util.collection.BitSet
-
-/**
- * Performs an sort merge outer join of two child relations.
- */
-case class SortMergeOuterJoin(
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    joinType: JoinType,
-    condition: Option[Expression],
-    left: SparkPlan,
-    right: SparkPlan) extends BinaryNode {
-
-  override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
-  override def output: Seq[Attribute] = {
-    joinType match {
-      case LeftOuter =>
-        left.output ++ right.output.map(_.withNullability(true))
-      case RightOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output
-      case FullOuter =>
-        (left.output ++ right.output).map(_.withNullability(true))
-      case x =>
-        throw new IllegalArgumentException(
-          s"${getClass.getSimpleName} should not take $x as the JoinType")
-    }
-  }
-
-  override def outputPartitioning: Partitioning = joinType match {
-    // For left and right outer joins, the output is partitioned by the streamed input's join keys.
-    case LeftOuter => left.outputPartitioning
-    case RightOuter => right.outputPartitioning
-    case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
-    case x =>
-      throw new IllegalArgumentException(
-        s"${getClass.getSimpleName} should not take $x as the JoinType")
-  }
-
-  override def outputOrdering: Seq[SortOrder] = joinType match {
-    // For left and right outer joins, the output is ordered by the streamed input's join keys.
-    case LeftOuter => requiredOrders(leftKeys)
-    case RightOuter => requiredOrders(rightKeys)
-    // there are null rows in both streams, so there is no order
-    case FullOuter => Nil
-    case x => throw new IllegalArgumentException(
-      s"SortMergeOuterJoin should not take $x as the JoinType")
-  }
-
-  override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil
-
-  private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = {
-    // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`.
-    keys.map(SortOrder(_, Ascending))
-  }
-
-  private def createLeftKeyGenerator(): Projection =
-    UnsafeProjection.create(leftKeys, left.output)
-
-  private def createRightKeyGenerator(): Projection =
-    UnsafeProjection.create(rightKeys, right.output)
-
-  override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-
-    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
-      // An ordering that can be used to compare keys from both sides.
-      val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType))
-      val boundCondition: (InternalRow) => Boolean = {
-        condition.map { cond =>
-          newPredicate(cond, left.output ++ right.output)
-        }.getOrElse {
-          (r: InternalRow) => true
-        }
-      }
-      val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output)
-
-      joinType match {
-        case LeftOuter =>
-          val smjScanner = new SortMergeJoinScanner(
-            streamedKeyGenerator = createLeftKeyGenerator(),
-            bufferedKeyGenerator = createRightKeyGenerator(),
-            keyOrdering,
-            streamedIter = RowIterator.fromScala(leftIter),
-            bufferedIter = RowIterator.fromScala(rightIter)
-          )
-          val rightNullRow = new GenericInternalRow(right.output.length)
-          new LeftOuterIterator(
-            smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows).toScala
-
-        case RightOuter =>
-          val smjScanner = new SortMergeJoinScanner(
-            streamedKeyGenerator = createRightKeyGenerator(),
-            bufferedKeyGenerator = createLeftKeyGenerator(),
-            keyOrdering,
-            streamedIter = RowIterator.fromScala(rightIter),
-            bufferedIter = RowIterator.fromScala(leftIter)
-          )
-          val leftNullRow = new GenericInternalRow(left.output.length)
-          new RightOuterIterator(
-            smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows).toScala
-
-        case FullOuter =>
-          val leftNullRow = new GenericInternalRow(left.output.length)
-          val rightNullRow = new GenericInternalRow(right.output.length)
-          val smjScanner = new SortMergeFullOuterJoinScanner(
-            leftKeyGenerator = createLeftKeyGenerator(),
-            rightKeyGenerator = createRightKeyGenerator(),
-            keyOrdering,
-            leftIter = RowIterator.fromScala(leftIter),
-            rightIter = RowIterator.fromScala(rightIter),
-            boundCondition,
-            leftNullRow,
-            rightNullRow)
-
-          new FullOuterIterator(
-            smjScanner,
-            resultProj,
-            numOutputRows).toScala
-
-        case x =>
-          throw new IllegalArgumentException(
-            s"SortMergeOuterJoin should not take $x as the JoinType")
-      }
-    }
-  }
-}
-
-/**
- * An iterator for outputting rows in left outer join.
- */
-private class LeftOuterIterator(
-    smjScanner: SortMergeJoinScanner,
-    rightNullRow: InternalRow,
-    boundCondition: InternalRow => Boolean,
-    resultProj: InternalRow => InternalRow,
-    numOutputRows: LongSQLMetric)
-  extends OneSideOuterIterator(
-    smjScanner, rightNullRow, boundCondition, resultProj, numOutputRows) {
-
-  protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
-  protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
-}
-
-/**
- * An iterator for outputting rows in right outer join.
- */
-private class RightOuterIterator(
-    smjScanner: SortMergeJoinScanner,
-    leftNullRow: InternalRow,
-    boundCondition: InternalRow => Boolean,
-    resultProj: InternalRow => InternalRow,
-    numOutputRows: LongSQLMetric)
-  extends OneSideOuterIterator(
-    smjScanner, leftNullRow, boundCondition, resultProj, numOutputRows) {
-
-  protected override def setStreamSideOutput(row: InternalRow): Unit = joinedRow.withRight(row)
-  protected override def setBufferedSideOutput(row: InternalRow): Unit = joinedRow.withLeft(row)
-}
-
-/**
- * An abstract iterator for sharing code between [[LeftOuterIterator]] and [[RightOuterIterator]].
- *
- * Each [[OneSideOuterIterator]] has a streamed side and a buffered side. Each row on the
- * streamed side will output 0 or many rows, one for each matching row on the buffered side.
- * If there are no matches, then the buffered side of the joined output will be a null row.
- *
- * In left outer join, the left is the streamed side and the right is the buffered side.
- * In right outer join, the right is the streamed side and the left is the buffered side.
- *
- * @param smjScanner a scanner that streams rows and buffers any matching rows
- * @param bufferedSideNullRow the default row to return when a streamed row has no matches
- * @param boundCondition an additional filter condition for buffered rows
- * @param resultProj how the output should be projected
- * @param numOutputRows an accumulator metric for the number of rows output
- */
-private abstract class OneSideOuterIterator(
-    smjScanner: SortMergeJoinScanner,
-    bufferedSideNullRow: InternalRow,
-    boundCondition: InternalRow => Boolean,
-    resultProj: InternalRow => InternalRow,
-    numOutputRows: LongSQLMetric) extends RowIterator {
-
-  // A row to store the joined result, reused many times
-  protected[this] val joinedRow: JoinedRow = new JoinedRow()
-
-  // Index of the buffered rows, reset to 0 whenever we advance to a new streamed row
-  private[this] var bufferIndex: Int = 0
-
-  // This iterator is initialized lazily so there should be no matches initially
-  assert(smjScanner.getBufferedMatches.length == 0)
-
-  // Set output methods to be overridden by subclasses
-  protected def setStreamSideOutput(row: InternalRow): Unit
-  protected def setBufferedSideOutput(row: InternalRow): Unit
-
-  /**
-   * Advance to the next row on the stream side and populate the buffer with matches.
-   * @return whether there are more rows in the stream to consume.
-   */
-  private def advanceStream(): Boolean = {
-    bufferIndex = 0
-    if (smjScanner.findNextOuterJoinRows()) {
-      setStreamSideOutput(smjScanner.getStreamedRow)
-      if (smjScanner.getBufferedMatches.isEmpty) {
-        // There are no matching rows in the buffer, so return the null row
-        setBufferedSideOutput(bufferedSideNullRow)
-      } else {
-        // Find the next row in the buffer that satisfied the bound condition
-        if (!advanceBufferUntilBoundConditionSatisfied()) {
-          setBufferedSideOutput(bufferedSideNullRow)
-        }
-      }
-      true
-    } else {
-      // Stream has been exhausted
-      false
-    }
-  }
-
-  /**
-   * Advance to the next row in the buffer that satisfies the bound condition.
-   * @return whether there is such a row in the current buffer.
-   */
-  private def advanceBufferUntilBoundConditionSatisfied(): Boolean = {
-    var foundMatch: Boolean = false
-    while (!foundMatch && bufferIndex < smjScanner.getBufferedMatches.length) {
-      setBufferedSideOutput(smjScanner.getBufferedMatches(bufferIndex))
-      foundMatch = boundCondition(joinedRow)
-      bufferIndex += 1
-    }
-    foundMatch
-  }
-
-  override def advanceNext(): Boolean = {
-    val r = advanceBufferUntilBoundConditionSatisfied() || advanceStream()
-    if (r) numOutputRows += 1
-    r
-  }
-
-  override def getRow: InternalRow = resultProj(joinedRow)
-}
-
-private class SortMergeFullOuterJoinScanner(
-    leftKeyGenerator: Projection,
-    rightKeyGenerator: Projection,
-    keyOrdering: Ordering[InternalRow],
-    leftIter: RowIterator,
-    rightIter: RowIterator,
-    boundCondition: InternalRow => Boolean,
-    leftNullRow: InternalRow,
-    rightNullRow: InternalRow)  {
-  private[this] val joinedRow: JoinedRow = new JoinedRow()
-  private[this] var leftRow: InternalRow = _
-  private[this] var leftRowKey: InternalRow = _
-  private[this] var rightRow: InternalRow = _
-  private[this] var rightRowKey: InternalRow = _
-
-  private[this] var leftIndex: Int = 0
-  private[this] var rightIndex: Int = 0
-  private[this] val leftMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
-  private[this] val rightMatches: ArrayBuffer[InternalRow] = new ArrayBuffer[InternalRow]
-  private[this] var leftMatched: BitSet = new BitSet(1)
-  private[this] var rightMatched: BitSet = new BitSet(1)
-
-  advancedLeft()
-  advancedRight()
-
-  // --- Private methods --------------------------------------------------------------------------
-
-  /**
-   * Advance the left iterator and compute the new row's join key.
-   * @return true if the left iterator returned a row and false otherwise.
-   */
-  private def advancedLeft(): Boolean = {
-    if (leftIter.advanceNext()) {
-      leftRow = leftIter.getRow
-      leftRowKey = leftKeyGenerator(leftRow)
-      true
-    } else {
-      leftRow = null
-      leftRowKey = null
-      false
-    }
-  }
-
-  /**
-   * Advance the right iterator and compute the new row's join key.
-   * @return true if the right iterator returned a row and false otherwise.
-   */
-  private def advancedRight(): Boolean = {
-    if (rightIter.advanceNext()) {
-      rightRow = rightIter.getRow
-      rightRowKey = rightKeyGenerator(rightRow)
-      true
-    } else {
-      rightRow = null
-      rightRowKey = null
-      false
-    }
-  }
-
-  /**
-   * Populate the left and right buffers with rows matching the provided key.
-   * This consumes rows from both iterators until their keys are different from the matching key.
-   */
-  private def findMatchingRows(matchingKey: InternalRow): Unit = {
-    leftMatches.clear()
-    rightMatches.clear()
-    leftIndex = 0
-    rightIndex = 0
-
-    while (leftRowKey != null && keyOrdering.compare(leftRowKey, matchingKey) == 0) {
-      leftMatches += leftRow.copy()
-      advancedLeft()
-    }
-    while (rightRowKey != null && keyOrdering.compare(rightRowKey, matchingKey) == 0) {
-      rightMatches += rightRow.copy()
-      advancedRight()
-    }
-
-    if (leftMatches.size <= leftMatched.capacity) {
-      leftMatched.clear()
-    } else {
-      leftMatched = new BitSet(leftMatches.size)
-    }
-    if (rightMatches.size <= rightMatched.capacity) {
-      rightMatched.clear()
-    } else {
-      rightMatched = new BitSet(rightMatches.size)
-    }
-  }
-
-  /**
-   * Scan the left and right buffers for the next valid match.
-   *
-   * Note: this method mutates `joinedRow` to point to the latest matching rows in the buffers.
-   * If a left row has no valid matches on the right, or a right row has no valid matches on the
-   * left, then the row is joined with the null row and the result is considered a valid match.
-   *
-   * @return true if a valid match is found, false otherwise.
-   */
-  private def scanNextInBuffered(): Boolean = {
-    while (leftIndex < leftMatches.size) {
-      while (rightIndex < rightMatches.size) {
-        joinedRow(leftMatches(leftIndex), rightMatches(rightIndex))
-        if (boundCondition(joinedRow)) {
-          leftMatched.set(leftIndex)
-          rightMatched.set(rightIndex)
-          rightIndex += 1
-          return true
-        }
-        rightIndex += 1
-      }
-      rightIndex = 0
-      if (!leftMatched.get(leftIndex)) {
-        // the left row has never matched any right row, join it with null row
-        joinedRow(leftMatches(leftIndex), rightNullRow)
-        leftIndex += 1
-        return true
-      }
-      leftIndex += 1
-    }
-
-    while (rightIndex < rightMatches.size) {
-      if (!rightMatched.get(rightIndex)) {
-        // the right row has never matched any left row, join it with null row
-        joinedRow(leftNullRow, rightMatches(rightIndex))
-        rightIndex += 1
-        return true
-      }
-      rightIndex += 1
-    }
-
-    // There are no more valid matches in the left and right buffers
-    false
-  }
-
-  // --- Public methods --------------------------------------------------------------------------
-
-  def getJoinedRow(): JoinedRow = joinedRow
-
-  def advanceNext(): Boolean = {
-    // If we already buffered some matching rows, use them directly
-    if (leftIndex <= leftMatches.size || rightIndex <= rightMatches.size) {
-      if (scanNextInBuffered()) {
-        return true
-      }
-    }
-
-    if (leftRow != null && (leftRowKey.anyNull || rightRow == null)) {
-      joinedRow(leftRow.copy(), rightNullRow)
-      advancedLeft()
-      true
-    } else if (rightRow != null && (rightRowKey.anyNull || leftRow == null)) {
-      joinedRow(leftNullRow, rightRow.copy())
-      advancedRight()
-      true
-    } else if (leftRow != null && rightRow != null) {
-      // Both rows are present and neither have null values,
-      // so we populate the buffers with rows matching the next key
-      val comp = keyOrdering.compare(leftRowKey, rightRowKey)
-      if (comp <= 0) {
-        findMatchingRows(leftRowKey.copy())
-      } else {
-        findMatchingRows(rightRowKey.copy())
-      }
-      scanNextInBuffered()
-      true
-    } else {
-      // Both iterators have been consumed
-      false
-    }
-  }
-}
-
-private class FullOuterIterator(
-    smjScanner: SortMergeFullOuterJoinScanner,
-    resultProj: InternalRow => InternalRow,
-    numRows: LongSQLMetric
-  ) extends RowIterator {
-  private[this] val joinedRow: JoinedRow = smjScanner.getJoinedRow()
-
-  override def advanceNext(): Boolean = {
-    val r = smjScanner.advanceNext()
-    if (r) numRows += 1
-    r
-  }
-
-  override def getRow: InternalRow = resultProj(joinedRow)
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/bbd887f5/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 50647c2..580e8d8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -51,7 +51,6 @@ class JoinSuite extends QueryTest with SharedSQLContext {
       case j: BroadcastNestedLoopJoin => j
       case j: BroadcastLeftSemiJoinHash => j
       case j: SortMergeJoin => j
-      case j: SortMergeOuterJoin => j
     }
 
     assert(operators.size === 1)
@@ -83,13 +82,13 @@ class JoinSuite extends QueryTest with SharedSQLContext {
         ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]),
         ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]),
         ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]),
-        ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeOuterJoin]),
+        ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[SortMergeJoin]),
         ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
           classOf[SortMergeJoin]),
         ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
-          classOf[SortMergeOuterJoin]),
+          classOf[SortMergeJoin]),
         ("SELECT * FROM testData full outer join testData2 ON key = a",
-          classOf[SortMergeOuterJoin]),
+          classOf[SortMergeJoin]),
         ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
           classOf[BroadcastNestedLoopJoin]),
         ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",

http://git-wip-us.apache.org/repos/asf/spark/blob/bbd887f5/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 88fbcda..9cd50ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
+import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
@@ -487,6 +488,7 @@ class PlannerSuite extends SharedSQLContext {
     val inputPlan = SortMergeJoin(
         Literal(1) :: Nil,
         Literal(1) :: Nil,
+        Inner,
         None,
         shuffle,
         shuffle)
@@ -503,6 +505,7 @@ class PlannerSuite extends SharedSQLContext {
     val inputPlan2 = SortMergeJoin(
       Literal(1) :: Nil,
       Literal(1) :: Nil,
+      Inner,
       None,
       ShuffleExchange(finalPartitioning, inputPlan),
       ShuffleExchange(finalPartitioning, inputPlan))

http://git-wip-us.apache.org/repos/asf/spark/blob/bbd887f5/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index eeb4440..814e25d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -108,7 +108,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         leftPlan: SparkPlan,
         rightPlan: SparkPlan) = {
       val sortMergeJoin =
-        joins.SortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan)
+        joins.SortMergeJoin(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan)
       EnsureRequirements(sqlContext.sessionState.conf).apply(sortMergeJoin)
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bbd887f5/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 4525486..547d062 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -94,12 +94,12 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
       }
     }
 
-    test(s"$testName using SortMergeOuterJoin") {
+    test(s"$testName using SortMergeJoin") {
       extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
             EnsureRequirements(sqlContext.sessionState.conf).apply(
-              SortMergeOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
+              SortMergeJoin(leftKeys, rightKeys, joinType, boundCondition, left, right)),
             expectedAnswer.map(Row.fromTuple),
             sortAnswers = true)
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/bbd887f5/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index f754acb..d7bd215 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -179,18 +179,18 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
-  test("SortMergeOuterJoin metrics") {
-    // Because SortMergeOuterJoin may skip different rows if the number of partitions is different,
+  test("SortMergeJoin(outer) metrics") {
+    // Because SortMergeJoin may skip different rows if the number of partitions is different,
     // this test should use the deterministic number of partitions.
     val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
     testDataForJoin.registerTempTable("testDataForJoin")
     withTempTable("testDataForJoin") {
       // Assume the execution plan is
-      // ... -> SortMergeOuterJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
+      // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0)
       val df = sqlContext.sql(
         "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a")
       testSparkPlanMetrics(df, 1, Map(
-        0L -> ("SortMergeOuterJoin", Map(
+        0L -> ("SortMergeJoin", Map(
           // It's 4 because we only read 3 rows in the first partition and 1 row in the second one
           "number of output rows" -> 8L)))
       )
@@ -198,7 +198,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
       val df2 = sqlContext.sql(
         "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a")
       testSparkPlanMetrics(df2, 1, Map(
-        0L -> ("SortMergeOuterJoin", Map(
+        0L -> ("SortMergeJoin", Map(
           // It's 4 because we only read 3 rows in the first partition and 1 row in the second one
           "number of output rows" -> 8L)))
       )


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org