Posted to commits@spark.apache.org by da...@apache.org on 2016/02/26 18:58:12 UTC

spark git commit: [SPARK-12313] [SQL] improve performance of BroadcastNestedLoopJoin

Repository: spark
Updated Branches:
  refs/heads/master 727e78014 -> 6df1e55a6


[SPARK-12313] [SQL] improve performance of BroadcastNestedLoopJoin

## What changes were proposed in this pull request?

Currently, BroadcastNestedLoopJoin is implemented for the worst case; it is too slow and can easily hang forever. This PR creates fast paths for some combinations of joinType and buildSide, and also improves the worst case (using much less memory than before).

Before this PR, one task required O(N*K) + O(K) memory in the worst case, where N is the number of rows from one partition of the streamed table and K is the number of rows on the build side; this could hang the job (because of GC pressure).

The only workaround for InnerJoin was to disable auto-broadcast and switch to CartesianProduct; see https://forums.databricks.com/questions/6747/how-do-i-get-a-cartesian-product-of-a-huge-dataset.html
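
For reference, the workaround looked roughly like this (a hypothetical spark-shell session; spark.sql.autoBroadcastJoinThreshold is a real conf key, the table names are placeholders):

    // Disable auto-broadcast so the planner falls back to CartesianProduct
    // for the inner join instead of BroadcastNestedLoopJoin.
    sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "-1")
    sqlContext.sql("SELECT * FROM big JOIN small ON big.key < small.key")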

In this PR, we add fast paths for these joins:

InnerJoin with BuildLeft or BuildRight
LeftOuterJoin with BuildRight
RightOuterJoin with BuildLeft
LeftSemi with BuildRight

These fast paths are all stream-based (they take one pass over the streamed table) and require O(1) memory, as sketched below.
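
To illustrate, a minimal plain-Scala model of the stream-based fast path for InnerJoin (a simplified sketch of the innerJoin method in the diff below; the real code operates on InternalRow with a reused JoinedRow):

    def innerJoinSketch[L, R](
        streamedIter: Iterator[L],
        buildRows: Array[R],
        condition: (L, R) => Boolean): Iterator[(L, R)] = {
      streamedIter.flatMap { streamedRow =>
        // One pass, O(1) extra memory: results are emitted lazily,
        // never buffered per partition.
        buildRows.iterator
          .filter(r => condition(streamedRow, r))
          .map(r => (streamedRow, r))
      }
    }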

All other join types and build sides take two passes over the streamed table. The first pass finds the matched rows (which include the streamed part) and requires O(1) memory; the second pass finds the rows from the build table that have no matching row in the streamed table, and requires O(K) memory, where K is the number of rows on the build side (one bit per row, which should be much smaller than the memory used for the broadcast itself). The following join types work this way (see the sketch after this list):

LeftOuterJoin with BuildLeft
RightOuterJoin with BuildRight
FullOuterJoin with BuildLeft or BuildRight
LeftSemi with BuildLeft
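
A minimal plain-Scala model of the two passes (a simplified sketch of the defaultJoin method in the diff below; the real code uses Spark's BitSet and merges the per-partition bitsets with RDD.fold):

    import scala.collection.mutable.BitSet

    // Pass 1, per streamed partition: mark every build row that matches
    // some streamed row. The per-partition bitsets are then OR-ed together.
    def matchedBuildRows[L, R](
        streamedIter: Iterator[L],
        buildRows: Array[R],
        condition: (L, R) => Boolean): BitSet = {
      val matched = new BitSet(buildRows.length)
      streamedIter.foreach { streamedRow =>
        var i = 0
        while (i < buildRows.length) {
          if (condition(streamedRow, buildRows(i))) matched.add(i)
          i += 1
        }
      }
      matched
    }

    // Pass 2 uses the merged bitset: build rows whose bit is still clear
    // had no match anywhere, so they are joined with nulls (outer joins);
    // for LeftSemi with BuildLeft, it is the rows whose bit is set that
    // are emitted instead.
    def unmatchedBuildRows[R](buildRows: Array[R], matched: BitSet): Seq[R] =
      buildRows.indices.filterNot(matched.contains).map(i => buildRows(i))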

This PR also adds tests for all the join types of BroadcastNestedLoopJoin.

After this PR, for InnerJoin with one small table, BroadcastNestedLoopJoin should be faster than CartesianProduct, so the workaround above is no longer needed.

## How was this patch tested?

Added unit tests.

Author: Davies Liu <da...@databricks.com>

Closes #11328 from davies/nested_loop.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6df1e55a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6df1e55a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6df1e55a

Branch: refs/heads/master
Commit: 6df1e55a6594ae4bc7882f44af8d230aad9489b4
Parents: 727e780
Author: Davies Liu <da...@databricks.com>
Authored: Fri Feb 26 09:58:05 2016 -0800
Committer: Davies Liu <da...@gmail.com>
Committed: Fri Feb 26 09:58:05 2016 -0800

----------------------------------------------------------------------
 .../catalyst/plans/physical/broadcastMode.scala |   1 +
 .../spark/sql/execution/SparkStrategies.scala   |  14 +-
 .../joins/BroadcastNestedLoopJoin.scala         | 295 ++++++++++++++-----
 .../scala/org/apache/spark/sql/JoinSuite.scala  |  11 +-
 .../sql/execution/joins/InnerJoinSuite.scala    |  27 ++
 .../sql/execution/joins/OuterJoinSuite.scala    |  18 ++
 .../sql/execution/joins/SemiJoinSuite.scala     |  20 +-
 7 files changed, 295 insertions(+), 91 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6df1e55a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
index c646dcf..e01f69f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
@@ -31,5 +31,6 @@ trait BroadcastMode {
  * IdentityBroadcastMode requires that rows are broadcasted in their original form.
  */
 case object IdentityBroadcastMode extends BroadcastMode {
+  // TODO: pack the UnsafeRows into single bytes array.
   override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6df1e55a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5fdf38c..dd8c96d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -253,22 +253,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
   object BroadcastNestedLoop extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case logical.Join(
-             CanBroadcast(left), right, joinType, condition) if joinType != LeftSemi =>
+      case j @ logical.Join(CanBroadcast(left), right, Inner | RightOuter, condition) =>
         execution.joins.BroadcastNestedLoopJoin(
-          planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil
-      case logical.Join(
-             left, CanBroadcast(right), joinType, condition) if joinType != LeftSemi =>
+          planLater(left), planLater(right), joins.BuildLeft, j.joinType, condition) :: Nil
+      case j @ logical.Join(left, CanBroadcast(right), Inner | LeftOuter | LeftSemi, condition) =>
         execution.joins.BroadcastNestedLoopJoin(
-          planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil
+          planLater(left), planLater(right), joins.BuildRight, j.joinType, condition) :: Nil
       case _ => Nil
     }
   }
 
   object CartesianProduct extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      // TODO CartesianProduct doesn't support the Left Semi Join
-      case logical.Join(left, right, joinType, None) if joinType != LeftSemi =>
+      case logical.Join(left, right, Inner, None) =>
         execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil
       case logical.Join(left, right, Inner, Some(condition)) =>
         execution.Filter(condition,
@@ -286,6 +283,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           } else {
             joins.BuildLeft
           }
+        // This join could be very slow or even hang forever
         joins.BroadcastNestedLoopJoin(
           planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
       case _ => Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/6df1e55a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index e8bd7f6..d83486d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.joins
 
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -26,7 +27,6 @@ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.util.collection.{BitSet, CompactBuffer}
 
-
 case class BroadcastNestedLoopJoin(
     left: SparkPlan,
     right: SparkPlan,
@@ -51,125 +51,266 @@ case class BroadcastNestedLoopJoin(
   }
 
   private[this] def genResultProjection: InternalRow => InternalRow = {
-    UnsafeProjection.create(schema)
+    if (joinType == LeftSemi) {
+      UnsafeProjection.create(output, output)
+    } else {
+      // Always put the stream side on left to simplify implementation
+      UnsafeProjection.create(output, streamed.output ++ broadcast.output)
+    }
   }
 
   override def outputPartitioning: Partitioning = streamed.outputPartitioning
 
   override def output: Seq[Attribute] = {
     joinType match {
+      case Inner =>
+        left.output ++ right.output
       case LeftOuter =>
         left.output ++ right.output.map(_.withNullability(true))
       case RightOuter =>
         left.output.map(_.withNullability(true)) ++ right.output
       case FullOuter =>
         left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
-      case Inner =>
-        // TODO we can avoid breaking the lineage, since we union an empty RDD for Inner Join case
-        left.output ++ right.output
-      case x => // TODO support the Left Semi Join
+      case LeftSemi =>
+        left.output
+      case x =>
         throw new IllegalArgumentException(
           s"BroadcastNestedLoopJoin should not take $x as the JoinType")
     }
   }
 
-  @transient private lazy val boundCondition =
-    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
+  @transient private lazy val boundCondition = {
+    if (condition.isDefined) {
+      newPredicate(condition.get, streamed.output ++ broadcast.output)
+    } else {
+      (r: InternalRow) => true
+    }
+  }
 
-  protected override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
+  /**
+   * The implementation for InnerJoin.
+   */
+  private def innerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+    streamed.execute().mapPartitionsInternal { streamedIter =>
+      val buildRows = relation.value
+      val joinedRow = new JoinedRow
 
-    val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
+      streamedIter.flatMap { streamedRow =>
+        val joinedRows = buildRows.iterator.map(r => joinedRow(streamedRow, r))
+        if (condition.isDefined) {
+          joinedRows.filter(boundCondition)
+        } else {
+          joinedRows
+        }
+      }
+    }
+  }
 
-    /** All rows that either match both-way, or rows from streamed joined with nulls. */
-    val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
-      val relation = broadcastedRelation.value
+  /**
+   * The implementation for these joins:
+   *
+   *   LeftOuter with BuildRight
+   *   RightOuter with BuildLeft
+   */
+  private def outerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+    streamed.execute().mapPartitionsInternal { streamedIter =>
+      val buildRows = relation.value
+      val joinedRow = new JoinedRow
+      val nulls = new GenericMutableRow(broadcast.output.size)
+
+      // Returns an iterator to avoid copy the rows.
+      new Iterator[InternalRow] {
+        // current row from stream side
+        private var streamRow: InternalRow = null
+        // have found a match for current row or not
+        private var foundMatch: Boolean = false
+        // the matched result row
+        private var resultRow: InternalRow = null
+        // the next index of buildRows to try
+        private var nextIndex: Int = 0
 
-      val matchedRows = new CompactBuffer[InternalRow]
-      val includedBroadcastTuples = new BitSet(relation.length)
+        private def findNextMatch(): Boolean = {
+          if (streamRow == null) {
+            if (!streamedIter.hasNext) {
+              return false
+            }
+            streamRow = streamedIter.next()
+            nextIndex = 0
+            foundMatch = false
+          }
+          while (nextIndex < buildRows.length) {
+            resultRow = joinedRow(streamRow, buildRows(nextIndex))
+            nextIndex += 1
+            if (boundCondition(resultRow)) {
+              foundMatch = true
+              return true
+            }
+          }
+          if (!foundMatch) {
+            resultRow = joinedRow(streamRow, nulls)
+            streamRow = null
+            true
+          } else {
+            resultRow = null
+            streamRow = null
+            findNextMatch()
+          }
+        }
+
+        override def hasNext(): Boolean = {
+          resultRow != null || findNextMatch()
+        }
+        override def next(): InternalRow = {
+          val r = resultRow
+          resultRow = null
+          r
+        }
+      }
+    }
+  }
+
+  /**
+   * The implementation for these joins:
+   *
+   *   LeftSemi with BuildRight
+   */
+  private def leftSemiJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+    assert(buildSide == BuildRight)
+    streamed.execute().mapPartitionsInternal { streamedIter =>
+      val buildRows = relation.value
       val joinedRow = new JoinedRow
 
-      val leftNulls = new GenericMutableRow(left.output.size)
-      val rightNulls = new GenericMutableRow(right.output.size)
-      val resultProj = genResultProjection
+      if (condition.isDefined) {
+        streamedIter.filter(l =>
+          buildRows.exists(r => boundCondition(joinedRow(l, r)))
+        )
+      } else {
+        streamedIter.filter(r => !buildRows.isEmpty)
+      }
+    }
+  }
+
+  /**
+   * The implementation for these joins:
+   *
+   *   LeftOuter with BuildLeft
+   *   RightOuter with BuildRight
+   *   FullOuter
+   *   LeftSemi with BuildLeft
+   */
+  private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+    /** All rows that either match both-way, or rows from streamed joined with nulls. */
+    val streamRdd = streamed.execute()
+
+    val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter =>
+      val buildRows = relation.value
+      val matched = new BitSet(buildRows.length)
+      val joinedRow = new JoinedRow
 
       streamedIter.foreach { streamedRow =>
         var i = 0
-        var streamRowMatched = false
-
-        while (i < relation.length) {
-          val broadcastedRow = relation(i)
-          buildSide match {
-            case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
-              matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy()
-              streamRowMatched = true
-              includedBroadcastTuples.set(i)
-            case BuildLeft if boundCondition(joinedRow(broadcastedRow, streamedRow)) =>
-              matchedRows += resultProj(joinedRow(broadcastedRow, streamedRow)).copy()
-              streamRowMatched = true
-              includedBroadcastTuples.set(i)
-            case _ =>
+        while (i < buildRows.length) {
+          if (boundCondition(joinedRow(streamedRow, buildRows(i)))) {
+            matched.set(i)
           }
           i += 1
         }
+      }
+      Seq(matched).toIterator
+    }
 
-        (streamRowMatched, joinType, buildSide) match {
-          case (false, LeftOuter | FullOuter, BuildRight) =>
-            matchedRows += resultProj(joinedRow(streamedRow, rightNulls)).copy()
-          case (false, RightOuter | FullOuter, BuildLeft) =>
-            matchedRows += resultProj(joinedRow(leftNulls, streamedRow)).copy()
-          case _ =>
+    val matchedBroadcastRows = matchedBuildRows.fold(
+      new BitSet(relation.value.length)
+    )(_ | _)
+
+    if (joinType == LeftSemi) {
+      assert(buildSide == BuildLeft)
+      val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+      var i = 0
+      val rel = relation.value
+      while (i < rel.length) {
+        if (matchedBroadcastRows.get(i)) {
+          buf += rel(i).copy()
         }
+        i += 1
       }
-      Iterator((matchedRows, includedBroadcastTuples))
+      return sparkContext.makeRDD(buf.toSeq)
     }
 
-    val includedBroadcastTuples = matchesOrStreamedRowsWithNulls.map(_._2)
-    val allIncludedBroadcastTuples = includedBroadcastTuples.fold(
-      new BitSet(broadcastedRelation.value.size)
-    )(_ | _)
+    val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter =>
+      val buildRows = relation.value
+      val joinedRow = new JoinedRow
+      val nulls = new GenericMutableRow(broadcast.output.size)
 
-    val leftNulls = new GenericMutableRow(left.output.size)
-    val rightNulls = new GenericMutableRow(right.output.size)
-    val resultProj = genResultProjection
+      streamedIter.flatMap { streamedRow =>
+        var i = 0
+        var foundMatch = false
+        val matchedRows = new CompactBuffer[InternalRow]
+
+        while (i < buildRows.length) {
+          if (boundCondition(joinedRow(streamedRow, buildRows(i)))) {
+            matchedRows += joinedRow.copy()
+            foundMatch = true
+          }
+          i += 1
+        }
+
+        if (!foundMatch && joinType == FullOuter) {
+          matchedRows += joinedRow(streamedRow, nulls).copy()
+        }
+        matchedRows.iterator
+      }
+    }
 
-    /** Rows from broadcasted joined with nulls. */
-    val broadcastRowsWithNulls: Seq[InternalRow] = {
+    val notMatchedBroadcastRows: Seq[InternalRow] = {
+      val nulls = new GenericMutableRow(streamed.output.size)
       val buf: CompactBuffer[InternalRow] = new CompactBuffer()
       var i = 0
-      val rel = broadcastedRelation.value
-      (joinType, buildSide) match {
-        case (RightOuter | FullOuter, BuildRight) =>
-          val joinedRow = new JoinedRow
-          joinedRow.withLeft(leftNulls)
-          while (i < rel.length) {
-            if (!allIncludedBroadcastTuples.get(i)) {
-              buf += resultProj(joinedRow.withRight(rel(i))).copy()
-            }
-            i += 1
-          }
-        case (LeftOuter | FullOuter, BuildLeft) =>
-          val joinedRow = new JoinedRow
-          joinedRow.withRight(rightNulls)
-          while (i < rel.length) {
-            if (!allIncludedBroadcastTuples.get(i)) {
-              buf += resultProj(joinedRow.withLeft(rel(i))).copy()
-            }
-            i += 1
-          }
-        case _ =>
+      val buildRows = relation.value
+      val joinedRow = new JoinedRow
+      joinedRow.withLeft(nulls)
+      while (i < buildRows.length) {
+        if (!matchedBroadcastRows.get(i)) {
+          buf += joinedRow.withRight(buildRows(i)).copy()
+        }
+        i += 1
       }
       buf.toSeq
     }
 
-    // TODO: Breaks lineage.
     sparkContext.union(
-      matchesOrStreamedRowsWithNulls.flatMap(_._1),
-      sparkContext.makeRDD(broadcastRowsWithNulls)
-    ).map { row =>
-      // `broadcastRowsWithNulls` doesn't run in a job so that we have to track numOutputRows here.
-      numOutputRows += 1
-      row
+      matchedStreamRows,
+      sparkContext.makeRDD(notMatchedBroadcastRows)
+    )
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
+
+    val resultRdd = (joinType, buildSide) match {
+      case (Inner, _) =>
+        innerJoin(broadcastedRelation)
+      case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
+        outerJoin(broadcastedRelation)
+      case (LeftSemi, BuildRight) =>
+        leftSemiJoin(broadcastedRelation)
+      case _ =>
+        /**
+         * LeftOuter with BuildLeft
+         * RightOuter with BuildRight
+         * FullOuter
+         * LeftSemi with BuildLeft
+         */
+        defaultJoin(broadcastedRelation)
+    }
+
+    val numOutputRows = longMetric("numOutputRows")
+    resultRdd.mapPartitionsInternal { iter =>
+      val resultProj = genResultProjection
+      iter.map { r =>
+        numOutputRows += 1
+        resultProj(r)
+      }
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6df1e55a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 41e27ec..3dab848 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -70,13 +70,14 @@ class JoinSuite extends QueryTest with SharedSQLContext {
         ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
         ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
         ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
-        ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
-        ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
-        ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
-        ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+        ("SELECT * FROM testData LEFT JOIN testData2", classOf[BroadcastNestedLoopJoin]),
+        ("SELECT * FROM testData RIGHT JOIN testData2", classOf[BroadcastNestedLoopJoin]),
+        ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[BroadcastNestedLoopJoin]),
+        ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2",
+          classOf[BroadcastNestedLoopJoin]),
         ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
         ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2",
-          classOf[CartesianProduct]),
+          classOf[BroadcastNestedLoopJoin]),
         ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
         ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a",
           classOf[CartesianProduct]),

http://git-wip-us.apache.org/repos/asf/spark/blob/6df1e55a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index b748229..7eb1524 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -146,6 +146,33 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
         }
       }
     }
+
+    test(s"$testName using CartesianProduct") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          Filter(condition(), CartesianProduct(left, right)),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
+
+    test(s"$testName using BroadcastNestedLoopJoin build left") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          BroadcastNestedLoopJoin(left, right, BuildLeft, Inner, Some(condition())),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
+
+    test(s"$testName using BroadcastNestedLoopJoin build right") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          BroadcastNestedLoopJoin(left, right, BuildRight, Inner, Some(condition())),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
   }
 
   testInnerJoin(

http://git-wip-us.apache.org/repos/asf/spark/blob/6df1e55a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 22fe8ca..0d1c29f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -105,6 +105,24 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
         }
       }
     }
+
+    test(s"$testName using BroadcastNestedLoopJoin build left") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          BroadcastNestedLoopJoin(left, right, BuildLeft, joinType, Some(condition)),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
+
+    test(s"$testName using BroadcastNestedLoopJoin build right") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          BroadcastNestedLoopJoin(left, right, BuildRight, joinType, Some(condition)),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
   }
 
   // --- Basic outer joins ------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/6df1e55a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 5c98288..355f916 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
-import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi}
 import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
@@ -103,6 +103,24 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
           sortAnswers = true)
       }
     }
+
+    test(s"$testName using BroadcastNestedLoopJoin build left") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          BroadcastNestedLoopJoin(left, right, BuildLeft, LeftSemi, Some(condition)),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
+
+    test(s"$testName using BroadcastNestedLoopJoin build right") {
+      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+          BroadcastNestedLoopJoin(left, right, BuildRight, LeftSemi, Some(condition)),
+          expectedAnswer.map(Row.fromTuple),
+          sortAnswers = true)
+      }
+    }
   }
 
   testLeftSemiJoin(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org