You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2015/07/06 23:33:36 UTC

spark git commit: [SPARK-4485] [SQL] (1) Add broadcast hash outer join, (2) Fix SparkPlanTest

Repository: spark
Updated Branches:
  refs/heads/master 37e4d9214 -> 2471c0bf7


[SPARK-4485] [SQL] (1) Add broadcast hash outer join, (2) Fix SparkPlanTest

This pull request:
(1) extracts common functions used by hash outer joins and puts them in the HashOuterJoin trait
(2) adds ShuffledHashOuterJoin and BroadcastHashOuterJoin
(3) adds test cases for shuffled and broadcast hash outer join
(4) makes SparkPlanTest support binary or more complex operators, and fixes bugs in plan composition in SparkPlanTest

Author: kai <ka...@eecs.berkeley.edu>

Closes #7162 from kai-zeng/outer and squashes the following commits:

3742359 [kai] Fix not-serializable exception for code-generated keys in broadcasted relations
14e4bf8 [kai] Use CanBroadcast in broadcast outer join planning
dc5127e [kai] code style fixes
b5a4efa [kai] (1) Add broadcast hash outer join, (2) Fix SparkPlanTest


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2471c0bf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2471c0bf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2471c0bf

Branch: refs/heads/master
Commit: 2471c0bf7f463bb144b44a2e51c0f363e71e099d
Parents: 37e4d92
Author: kai <ka...@eecs.berkeley.edu>
Authored: Mon Jul 6 14:33:30 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Mon Jul 6 14:33:30 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   |  12 +-
 .../joins/BroadcastHashOuterJoin.scala          | 121 +++++++++++++++++++
 .../sql/execution/joins/HashOuterJoin.scala     |  95 ++++-----------
 .../execution/joins/ShuffledHashOuterJoin.scala |  85 +++++++++++++
 .../scala/org/apache/spark/sql/JoinSuite.scala  |  40 +++++-
 .../spark/sql/execution/SparkPlanTest.scala     |  99 ++++++++++++---
 .../sql/execution/joins/OuterJoinSuite.scala    |  88 ++++++++++++++
 7 files changed, 441 insertions(+), 99 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2471c0bf/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5daf86d..3204498 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -117,8 +117,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           leftKeys, rightKeys, buildSide, planLater(left), planLater(right))
         condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
 
+      case ExtractEquiJoinKeys(
+             LeftOuter, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
+        joins.BroadcastHashOuterJoin(
+          leftKeys, rightKeys, LeftOuter, condition, planLater(left), planLater(right)) :: Nil
+
+      case ExtractEquiJoinKeys(
+             RightOuter, leftKeys, rightKeys, condition, CanBroadcast(left), right) =>
+        joins.BroadcastHashOuterJoin(
+          leftKeys, rightKeys, RightOuter, condition, planLater(left), planLater(right)) :: Nil
+
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
-        joins.HashOuterJoin(
+        joins.ShuffledHashOuterJoin(
           leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
 
       case _ => Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/2471c0bf/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
new file mode 100644
index 0000000..5da04c7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.ThreadUtils
+
+import scala.collection.JavaConversions._
+import scala.concurrent._
+import scala.concurrent.duration._
+
+/**
+ * :: DeveloperApi ::
+ * Performs a outer hash join for two child relations.  When the output RDD of this operator is
+ * being constructed, a Spark job is asynchronously started to calculate the values for the
+ * broadcasted relation.  This data is then placed in a Spark broadcast variable.  The streamed
+ * relation is not shuffled.
+ */
+@DeveloperApi
+case class BroadcastHashOuterJoin(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode with HashOuterJoin {
+
+  val timeout = {
+    val timeoutValue = sqlContext.conf.broadcastTimeout
+    if (timeoutValue < 0) {
+      Duration.Inf
+    } else {
+      timeoutValue.seconds
+    }
+  }
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
+
+  private[this] lazy val (buildPlan, streamedPlan) = joinType match {
+    case RightOuter => (left, right)
+    case LeftOuter => (right, left)
+    case x =>
+      throw new IllegalArgumentException(
+        s"BroadcastHashOuterJoin should not take $x as the JoinType")
+  }
+
+  private[this] lazy val (buildKeys, streamedKeys) = joinType match {
+    case RightOuter => (leftKeys, rightKeys)
+    case LeftOuter => (rightKeys, leftKeys)
+    case x =>
+      throw new IllegalArgumentException(
+        s"BroadcastHashOuterJoin should not take $x as the JoinType")
+  }
+
+  @transient
+  private val broadcastFuture = future {
+    // Note that we use .execute().collect() because we don't want to convert data to Scala types
+    val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
+    // buildHashTable uses code-generated rows as keys, which are not serializable
+    val hashed =
+      buildHashTable(input.iterator, new InterpretedProjection(buildKeys, buildPlan.output))
+    sparkContext.broadcast(hashed)
+  }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)
+
+  override def doExecute(): RDD[InternalRow] = {
+    val broadcastRelation = Await.result(broadcastFuture, timeout)
+
+    streamedPlan.execute().mapPartitions { streamedIter =>
+      val joinedRow = new JoinedRow()
+      val hashTable = broadcastRelation.value
+      val keyGenerator = newProjection(streamedKeys, streamedPlan.output)
+
+      joinType match {
+        case LeftOuter =>
+          streamedIter.flatMap(currentRow => {
+            val rowKey = keyGenerator(currentRow)
+            joinedRow.withLeft(currentRow)
+            leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST))
+          })
+
+        case RightOuter =>
+          streamedIter.flatMap(currentRow => {
+            val rowKey = keyGenerator(currentRow)
+            joinedRow.withRight(currentRow)
+            rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
+          })
+
+        case x =>
+          throw new IllegalArgumentException(
+            s"BroadcastHashOuterJoin should not take $x as the JoinType")
+      }
+    }
+  }
+}
+
+object BroadcastHashOuterJoin {
+
+  private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128))
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2471c0bf/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
index e41538e..886b5fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala
@@ -19,32 +19,25 @@ package org.apache.spark.sql.execution.joins
 
 import java.util.{HashMap => JavaHashMap}
 
-import org.apache.spark.rdd.RDD
-
-import scala.collection.JavaConversions._
-
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.util.collection.CompactBuffer
 
-/**
- * :: DeveloperApi ::
- * Performs a hash based outer join for two child relations by shuffling the data using
- * the join keys. This operator requires loading the associated partition in both side into memory.
- */
 @DeveloperApi
-case class HashOuterJoin(
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    joinType: JoinType,
-    condition: Option[Expression],
-    left: SparkPlan,
-    right: SparkPlan) extends BinaryNode {
-
-  override def outputPartitioning: Partitioning = joinType match {
+trait HashOuterJoin {
+  self: SparkPlan =>
+
+  val leftKeys: Seq[Expression]
+  val rightKeys: Seq[Expression]
+  val joinType: JoinType
+  val condition: Option[Expression]
+  val left: SparkPlan
+  val right: SparkPlan
+
+override def outputPartitioning: Partitioning = joinType match {
     case LeftOuter => left.outputPartitioning
     case RightOuter => right.outputPartitioning
     case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
@@ -52,9 +45,6 @@ case class HashOuterJoin(
       throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
   }
 
-  override def requiredChildDistribution: Seq[ClusteredDistribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
   override def output: Seq[Attribute] = {
     joinType match {
       case LeftOuter =>
@@ -68,8 +58,8 @@ case class HashOuterJoin(
     }
   }
 
-  @transient private[this] lazy val DUMMY_LIST = Seq[InternalRow](null)
-  @transient private[this] lazy val EMPTY_LIST = Seq.empty[InternalRow]
+  @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
+  @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()
 
   @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
   @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)
@@ -80,7 +70,7 @@ case class HashOuterJoin(
   // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
   // iterator for performance purpose.
 
-  private[this] def leftOuterIterator(
+  protected[this] def leftOuterIterator(
       key: InternalRow,
       joinedRow: JoinedRow,
       rightIter: Iterable[InternalRow]): Iterator[InternalRow] = {
@@ -89,7 +79,7 @@ case class HashOuterJoin(
         val temp = rightIter.collect {
           case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy()
         }
-        if (temp.size == 0) {
+        if (temp.isEmpty) {
           joinedRow.withRight(rightNullRow).copy :: Nil
         } else {
           temp
@@ -101,18 +91,17 @@ case class HashOuterJoin(
     ret.iterator
   }
 
-  private[this] def rightOuterIterator(
+  protected[this] def rightOuterIterator(
       key: InternalRow,
       leftIter: Iterable[InternalRow],
       joinedRow: JoinedRow): Iterator[InternalRow] = {
-
     val ret: Iterable[InternalRow] = {
       if (!key.anyNull) {
         val temp = leftIter.collect {
           case l if boundCondition(joinedRow.withLeft(l)) =>
-            joinedRow.copy
+            joinedRow.copy()
         }
-        if (temp.size == 0) {
+        if (temp.isEmpty) {
           joinedRow.withLeft(leftNullRow).copy :: Nil
         } else {
           temp
@@ -124,10 +113,9 @@ case class HashOuterJoin(
     ret.iterator
   }
 
-  private[this] def fullOuterIterator(
+  protected[this] def fullOuterIterator(
       key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow],
       joinedRow: JoinedRow): Iterator[InternalRow] = {
-
     if (!key.anyNull) {
       // Store the positions of records in right, if one of its associated row satisfy
       // the join condition.
@@ -171,7 +159,7 @@ case class HashOuterJoin(
     }
   }
 
-  private[this] def buildHashTable(
+  protected[this] def buildHashTable(
       iter: Iterator[InternalRow],
       keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = {
     val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]()
@@ -190,43 +178,4 @@ case class HashOuterJoin(
 
     hashTable
   }
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val joinedRow = new JoinedRow()
-    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
-      // TODO this probably can be replaced by external sort (sort merged join?)
-
-      joinType match {
-        case LeftOuter =>
-          val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
-          val keyGenerator = newProjection(leftKeys, left.output)
-          leftIter.flatMap( currentRow => {
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withLeft(currentRow)
-            leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST))
-          })
-
-        case RightOuter =>
-          val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
-          val keyGenerator = newProjection(rightKeys, right.output)
-          rightIter.flatMap ( currentRow => {
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withRight(currentRow)
-            rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
-          })
-
-        case FullOuter =>
-          val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
-          val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
-          (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
-            fullOuterIterator(key,
-              leftHashTable.getOrElse(key, EMPTY_LIST),
-              rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow)
-          }
-
-        case x =>
-          throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
-      }
-    }
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2471c0bf/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
new file mode 100644
index 0000000..cfc9c14
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution}
+import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+
+import scala.collection.JavaConversions._
+
+/**
+ * :: DeveloperApi ::
+ * Performs a hash based outer join for two child relations by shuffling the data using
+ * the join keys. This operator requires loading the associated partition in both side into memory.
+ */
+@DeveloperApi
+case class ShuffledHashOuterJoin(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode with HashOuterJoin {
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val joinedRow = new JoinedRow()
+    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
+      // TODO this probably can be replaced by external sort (sort merged join?)
+      joinType match {
+        case LeftOuter =>
+          val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
+          val keyGenerator = newProjection(leftKeys, left.output)
+          leftIter.flatMap( currentRow => {
+            val rowKey = keyGenerator(currentRow)
+            joinedRow.withLeft(currentRow)
+            leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST))
+          })
+
+        case RightOuter =>
+          val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
+          val keyGenerator = newProjection(rightKeys, right.output)
+          rightIter.flatMap ( currentRow => {
+            val rowKey = keyGenerator(currentRow)
+            joinedRow.withRight(currentRow)
+            rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
+          })
+
+        case FullOuter =>
+          val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
+          val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
+          (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
+            fullOuterIterator(key,
+              leftHashTable.getOrElse(key, EMPTY_LIST),
+              rightHashTable.getOrElse(key, EMPTY_LIST),
+              joinedRow)
+          }
+
+        case x =>
+          throw new IllegalArgumentException(
+            s"ShuffledHashOuterJoin should not take $x as the JoinType")
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2471c0bf/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 20390a5..8953889 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -45,9 +45,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
     val physical = df.queryExecution.sparkPlan
     val operators = physical.collect {
       case j: ShuffledHashJoin => j
-      case j: HashOuterJoin => j
+      case j: ShuffledHashOuterJoin => j
       case j: LeftSemiJoinHash => j
       case j: BroadcastHashJoin => j
+      case j: BroadcastHashOuterJoin => j
       case j: LeftSemiJoinBNL => j
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
@@ -81,12 +82,13 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
       ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
       ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
       ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
-      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
+      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]),
       ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
-        classOf[HashOuterJoin]),
+        classOf[ShuffledHashOuterJoin]),
       ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
-        classOf[HashOuterJoin]),
-      ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]),
+        classOf[ShuffledHashOuterJoin]),
+      ("SELECT * FROM testData full outer join testData2 ON key = a",
+        classOf[ShuffledHashOuterJoin]),
       ("SELECT * FROM testData left JOIN testData2 ON (key * a != key + a)",
         classOf[BroadcastNestedLoopJoin]),
       ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)",
@@ -133,6 +135,34 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
     ctx.sql("UNCACHE TABLE testData")
   }
 
+  test("broadcasted hash outer join operator selection") {
+    ctx.cacheManager.clearCache()
+    ctx.sql("CACHE TABLE testData")
+
+    val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled
+    Seq(
+      ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]),
+      ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
+        classOf[BroadcastHashOuterJoin]),
+      ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
+        classOf[BroadcastHashOuterJoin])
+    ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    try {
+      ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true)
+      Seq(
+        ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[ShuffledHashOuterJoin]),
+        ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
+          classOf[BroadcastHashOuterJoin]),
+        ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
+          classOf[BroadcastHashOuterJoin])
+      ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+    } finally {
+      ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED)
+    }
+
+    ctx.sql("UNCACHE TABLE testData")
+  }
+
   test("multiple-key equi-join is hash-join") {
     val x = testData2.as("x")
     val y = testData2.as("y")

http://git-wip-us.apache.org/repos/asf/spark/blob/2471c0bf/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 13f3be8..108b112 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -54,6 +54,37 @@ class SparkPlanTest extends SparkFunSuite {
       input: DataFrame,
       planFunction: SparkPlan => SparkPlan,
       expectedAnswer: Seq[Row]): Unit = {
+    checkAnswer(input :: Nil, (plans: Seq[SparkPlan]) => planFunction(plans.head), expectedAnswer)
+  }
+
+  /**
+   * Runs the plan and makes sure the answer matches the expected result.
+   * @param left the left input data to be used.
+   * @param right the right input data to be used.
+   * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
+   *                     the physical operator that's being tested.
+   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   */
+  protected def checkAnswer(
+      left: DataFrame,
+      right: DataFrame,
+      planFunction: (SparkPlan, SparkPlan) => SparkPlan,
+      expectedAnswer: Seq[Row]): Unit = {
+    checkAnswer(left :: right :: Nil,
+      (plans: Seq[SparkPlan]) => planFunction(plans(0), plans(1)), expectedAnswer)
+  }
+
+  /**
+   * Runs the plan and makes sure the answer matches the expected result.
+   * @param input the input data to be used.
+   * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
+   *                     the physical operator that's being tested.
+   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+   */
+  protected def checkAnswer(
+      input: Seq[DataFrame],
+      planFunction: Seq[SparkPlan] => SparkPlan,
+      expectedAnswer: Seq[Row]): Unit = {
     SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
@@ -72,11 +103,41 @@ class SparkPlanTest extends SparkFunSuite {
       planFunction: SparkPlan => SparkPlan,
       expectedAnswer: Seq[A]): Unit = {
     val expectedRows = expectedAnswer.map(Row.fromTuple)
-    SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match {
-      case Some(errorMessage) => fail(errorMessage)
-      case None =>
-    }
+    checkAnswer(input, planFunction, expectedRows)
+  }
+
+  /**
+   * Runs the plan and makes sure the answer matches the expected result.
+   * @param left the left input data to be used.
+   * @param right the right input data to be used.
+   * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
+   *                     the physical operator that's being tested.
+   * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+   */
+  protected def checkAnswer[A <: Product : TypeTag](
+      left: DataFrame,
+      right: DataFrame,
+      planFunction: (SparkPlan, SparkPlan) => SparkPlan,
+      expectedAnswer: Seq[A]): Unit = {
+    val expectedRows = expectedAnswer.map(Row.fromTuple)
+    checkAnswer(left, right, planFunction, expectedRows)
+  }
+
+  /**
+   * Runs the plan and makes sure the answer matches the expected result.
+   * @param input the input data to be used.
+   * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
+   *                     the physical operator that's being tested.
+   * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+   */
+  protected def checkAnswer[A <: Product : TypeTag](
+      input: Seq[DataFrame],
+      planFunction: Seq[SparkPlan] => SparkPlan,
+      expectedAnswer: Seq[A]): Unit = {
+    val expectedRows = expectedAnswer.map(Row.fromTuple)
+    checkAnswer(input, planFunction, expectedRows)
   }
+
 }
 
 /**
@@ -92,27 +153,25 @@ object SparkPlanTest {
    * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
    */
   def checkAnswer(
-      input: DataFrame,
-      planFunction: SparkPlan => SparkPlan,
+      input: Seq[DataFrame],
+      planFunction: Seq[SparkPlan] => SparkPlan,
       expectedAnswer: Seq[Row]): Option[String] = {
 
-    val outputPlan = planFunction(input.queryExecution.sparkPlan)
+    val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
 
     // A very simple resolver to make writing tests easier. In contrast to the real resolver
     // this is always case sensitive and does not try to handle scoping or complex type resolution.
-    val resolvedPlan = outputPlan transform {
-      case plan: SparkPlan =>
-        val inputMap = plan.children.flatMap(_.output).zipWithIndex.map {
-          case (a, i) =>
-            (a.name, BoundReference(i, a.dataType, a.nullable))
-        }.toMap
-
-        plan.transformExpressions {
-          case UnresolvedAttribute(Seq(u)) =>
-            inputMap.getOrElse(u,
-              sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
-        }
-    }
+    val resolvedPlan = TestSQLContext.prepareForExecution.execute(
+      outputPlan transform {
+        case plan: SparkPlan =>
+          val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
+          plan.transformExpressions {
+            case UnresolvedAttribute(Seq(u)) =>
+              inputMap.getOrElse(u,
+                sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
+          }
+      }
+    )
 
     def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
       // Converts data to types that we can do equality comparison using Scala collections.

http://git-wip-us.apache.org/repos/asf/spark/blob/2471c0bf/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
new file mode 100644
index 0000000..5707d2f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{Expression, LessThan}
+import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+
+class OuterJoinSuite extends SparkPlanTest {
+
+  val left = Seq(
+    (1, 2.0),
+    (2, 1.0),
+    (3, 3.0)
+  ).toDF("a", "b")
+
+  val right = Seq(
+    (2, 3.0),
+    (3, 2.0),
+    (4, 1.0)
+  ).toDF("c", "d")
+
+  val leftKeys: List[Expression] = 'a :: Nil
+  val rightKeys: List[Expression] = 'c :: Nil
+  val condition = Some(LessThan('b, 'd))
+
+  test("shuffled hash outer join") {
+    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+      ShuffledHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
+      Seq(
+        (1, 2.0, null, null),
+        (2, 1.0, 2, 3.0),
+        (3, 3.0, null, null)
+      ))
+
+    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+      ShuffledHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
+      Seq(
+        (2, 1.0, 2, 3.0),
+        (null, null, 3, 2.0),
+        (null, null, 4, 1.0)
+      ))
+
+    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+      ShuffledHashOuterJoin(leftKeys, rightKeys, FullOuter, condition, left, right),
+      Seq(
+        (1, 2.0, null, null),
+        (2, 1.0, 2, 3.0),
+        (3, 3.0, null, null),
+        (null, null, 3, 2.0),
+        (null, null, 4, 1.0)
+      ))
+  }
+
+  test("broadcast hash outer join") {
+    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastHashOuterJoin(leftKeys, rightKeys, LeftOuter, condition, left, right),
+      Seq(
+        (1, 2.0, null, null),
+        (2, 1.0, 2, 3.0),
+        (3, 3.0, null, null)
+      ))
+
+    checkAnswer(left, right, (left: SparkPlan, right: SparkPlan) =>
+      BroadcastHashOuterJoin(leftKeys, rightKeys, RightOuter, condition, left, right),
+      Seq(
+        (2, 1.0, 2, 3.0),
+        (null, null, 3, 2.0),
+        (null, null, 4, 1.0)
+      ))
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org