Posted to commits@spark.apache.org by da...@apache.org on 2016/03/16 06:17:09 UTC

spark git commit: [SPARK-13917] [SQL] generate broadcast semi join

Repository: spark
Updated Branches:
  refs/heads/master 52b6a899b -> 421f6c20e


[SPARK-13917] [SQL] generate broadcast semi join

## What changes were proposed in this pull request?

This PR adds codegen support for broadcast left-semi join.
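
For context, a hedged illustration (not code from this patch) of the kind of
query that now takes the generated path: a left-semi join whose right side is
small enough to broadcast is planned as a BroadcastHashJoin with the LeftSemi
join type. This sketch assumes an existing SQLContext named `sqlContext`.

    import org.apache.spark.sql.functions.col

    // Small dimension table: below the broadcast threshold.
    val dim = sqlContext.range(1000).selectExpr("id as k")

    // Left-semi join: keeps only left rows with a matching k; no columns
    // from `dim` appear in the output.
    sqlContext.range(1 << 20)
      .join(dim, col("id") === col("k"), "leftsemi")
      .count()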

## How was this patch tested?

Existing tests. Added a benchmark; the results show a 7X speedup.

Author: Davies Liu <da...@databricks.com>

Closes #11742 from davies/gen_semi.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/421f6c20
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/421f6c20
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/421f6c20

Branch: refs/heads/master
Commit: 421f6c20e85b32f6462d37dad6a62dec2d46ed88
Parents: 52b6a89
Author: Davies Liu <da...@databricks.com>
Authored: Tue Mar 15 22:17:04 2016 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Tue Mar 15 22:17:04 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/SparkStrategies.scala   |  4 +-
 .../sql/execution/joins/BroadcastHashJoin.scala | 81 ++++++++++++++++++--
 .../joins/BroadcastLeftSemiJoinHash.scala       | 57 --------------
 .../spark/sql/execution/joins/HashJoin.scala    | 23 +++++-
 .../sql/execution/joins/HashSemiJoin.scala      | 61 ---------------
 .../sql/execution/joins/LeftSemiJoinHash.scala  |  8 +-
 .../scala/org/apache/spark/sql/JoinSuite.scala  |  4 +-
 .../execution/BenchmarkWholeStageCodegen.scala  | 14 +++-
 .../execution/joins/BroadcastJoinSuite.scala    |  2 +-
 .../sql/execution/joins/SemiJoinSuite.scala     |  5 +-
 .../apache/spark/sql/hive/StatisticsSuite.scala |  4 +-
 11 files changed, 124 insertions(+), 139 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/421f6c20/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7fc6a82..121b6d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -65,8 +65,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case ExtractEquiJoinKeys(
              LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) =>
-        joins.BroadcastLeftSemiJoinHash(
-          leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil
+        Seq(joins.BroadcastHashJoin(
+          leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right)))
       // Find left semi joins where at least some predicates can be evaluated by matching join keys
       case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
         joins.LeftSemiJoinHash(
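
With this rule, planning reuses the general BroadcastHashJoin operator,
parameterized with LeftSemi and BuildRight, instead of a dedicated semi-join
node. The CanBroadcast extractor guarding the rule checks the estimated plan
size against the broadcast threshold; a simplified sketch of that gate (not
part of this diff, and details vary across Spark versions):

    // Simplified: allow broadcast only when statistics estimate the plan
    // to be no larger than spark.sql.autoBroadcastJoinThreshold.
    object CanBroadcast {
      def unapply(plan: LogicalPlan): Option[LogicalPlan] = {
        val threshold = sqlContext.conf.autoBroadcastJoinThreshold
        if (threshold > 0 && plan.statistics.sizeInBytes <= threshold) Some(plan)
        else None
      }
    }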

http://git-wip-us.apache.org/repos/asf/spark/blob/421f6c20/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 4c8f808..f84ed41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -92,6 +92,9 @@ case class BroadcastHashJoin(
             rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows)
           }
 
+        case LeftSemi =>
+          hashSemiJoin(streamedIter, hashTable, numOutputRows)
+
         case x =>
           throw new IllegalArgumentException(
             s"BroadcastHashJoin should not take $x as the JoinType")
@@ -108,11 +111,13 @@ case class BroadcastHashJoin(
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
-    if (joinType == Inner) {
-      codegenInner(ctx, input)
-    } else {
-      // LeftOuter and RightOuter
-      codegenOuter(ctx, input)
+    joinType match {
+      case Inner => codegenInner(ctx, input)
+      case LeftOuter | RightOuter => codegenOuter(ctx, input)
+      case LeftSemi => codegenSemi(ctx, input)
+      case x =>
+        throw new IllegalArgumentException(
+          s"BroadcastHashJoin should not take $x as the JoinType")
     }
   }
 
@@ -322,4 +327,68 @@ case class BroadcastHashJoin(
        """.stripMargin
     }
   }
+
+  /**
+   * Generates the code for left semi join.
+   */
+  private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+    val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+    val matched = ctx.freshName("matched")
+    val buildVars = genBuildSideVars(ctx, matched)
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    val checkCondition = if (condition.isDefined) {
+      val expr = condition.get
+      // evaluate the variables from the build side that are used by the condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+      // filter the output via condition
+      ctx.currentVars = input ++ buildVars
+      val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
+      s"""
+         |$eval
+         |${ev.code}
+         |if (${ev.isNull} || !${ev.value}) continue;
+       """.stripMargin
+    } else {
+      ""
+    }
+
+    if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+         |if ($matched == null) continue;
+         |$checkCondition
+         |$numOutput.add(1);
+         |${consume(ctx, input)}
+       """.stripMargin
+    } else {
+      val matches = ctx.freshName("matches")
+      val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+      val i = ctx.freshName("i")
+      val size = ctx.freshName("size")
+      val found = ctx.freshName("found")
+      s"""
+         |// generate join key for stream side
+         |${keyEv.code}
+         |// find matches from HashedRelation
+         |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
+         |if ($matches == null) continue;
+         |int $size = $matches.size();
+         |boolean $found = false;
+         |for (int $i = 0; $i < $size; $i++) {
+         |  UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+         |  $checkCondition
+         |  $found = true;
+         |  break;
+         |}
+         |if (!$found) continue;
+         |$numOutput.add(1);
+         |${consume(ctx, input)}
+       """.stripMargin
+    }
+  }
 }
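
The generated semi-join code above probes the broadcast hash table once per
stream row, emits the row at most once, and breaks out of the match loop as
soon as one build row passes the condition. In plain Scala, the probe is
equivalent to something like this standalone sketch (names are illustrative,
not from the patch):

    // Semi-join probe: keep a stream row iff at least one build-side row
    // with the same key satisfies the join condition.
    def semiProbe[K, R](
        stream: Iterator[R],
        streamKey: R => Option[K],   // None models a null join key
        build: Map[K, Seq[R]],
        condition: (R, R) => Boolean): Iterator[R] = {
      stream.filter { row =>
        streamKey(row).exists { k =>
          build.getOrElse(k, Nil).exists(b => condition(row, b))
        }
      }
    }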

http://git-wip-us.apache.org/repos/asf/spark/blob/421f6c20/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
deleted file mode 100644
index d3bcfad..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.TaskContext
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-/**
- * Build the right table's join keys into a HashedRelation, and iteratively go through the left
- * table, to find if the join keys are in the HashedRelation.
- */
-case class BroadcastLeftSemiJoinHash(
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    left: SparkPlan,
-    right: SparkPlan,
-    condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
-
-  override private[sql] lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
-  override def requiredChildDistribution: Seq[Distribution] = {
-    val mode = HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
-    UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
-  }
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-
-    val broadcastedRelation = right.executeBroadcast[HashedRelation]()
-    left.execute().mapPartitionsInternal { streamIter =>
-      val hashedRelation = broadcastedRelation.value
-      TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
-      hashSemiJoin(streamIter, hashedRelation, numOutputRows)
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/421f6c20/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 2fe9c06..5f42d07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -46,8 +46,8 @@ trait HashJoin {
         left.output ++ right.output.map(_.withNullability(true))
       case RightOuter =>
         left.output.map(_.withNullability(true)) ++ right.output
-      case FullOuter =>
-        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+      case LeftSemi =>
+        left.output
       case x =>
         throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType")
     }
@@ -104,7 +104,7 @@ trait HashJoin {
     keyExpr :: Nil
   }
 
-  protected val canJoinKeyFitWithinLong: Boolean = {
+  protected lazy val canJoinKeyFitWithinLong: Boolean = {
     val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType)
     val key = rewriteKeyExpr(buildKeys)
     sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType]
@@ -258,4 +258,21 @@ trait HashJoin {
     }
     ret.iterator
   }
+
+  protected def hashSemiJoin(
+    streamIter: Iterator[InternalRow],
+    hashedRelation: HashedRelation,
+    numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
+    val joinKeys = streamSideKeyGenerator
+    val joinedRow = new JoinedRow
+    streamIter.filter { current =>
+      val key = joinKeys(current)
+      lazy val rowBuffer = hashedRelation.get(key)
+      val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
+        (row: InternalRow) => boundCondition(joinedRow(current, row))
+      })
+      if (r) numOutputRows += 1
+      r
+    }
+  }
 }
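
A subtle detail above: canJoinKeyFitWithinLong changed from a val to a lazy
val. It is computed from buildKeys, which depends on buildSide, and buildSide
is now overridden as a plain val in subclasses such as LeftSemiJoinHash, so
an eager val could run before the override is initialized. A minimal,
self-contained demonstration of that Scala trait-initialization-order trap
(illustrative only, not from this patch):

    trait Eager { val x: Int; val y: Int = x + 1 }       // y computed at trait init
    trait Lazy  { val x: Int; lazy val y: Int = x + 1 }  // y deferred to first use

    object InitOrderDemo extends App {
      println((new Eager { val x = 41 }).y)  // prints 1: x was still 0 when y was set
      println((new Lazy  { val x = 41 }).y)  // prints 42: lazy access sees x = 41
    }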

http://git-wip-us.apache.org/repos/asf/spark/blob/421f6c20/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
deleted file mode 100644
index 813ec02..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.metric.LongSQLMetric
-
-
-trait HashSemiJoin {
-  self: SparkPlan =>
-  val leftKeys: Seq[Expression]
-  val rightKeys: Seq[Expression]
-  val left: SparkPlan
-  val right: SparkPlan
-  val condition: Option[Expression]
-
-  override def output: Seq[Attribute] = left.output
-
-  protected def leftKeyGenerator: Projection =
-    UnsafeProjection.create(leftKeys, left.output)
-
-  protected def rightKeyGenerator: Projection =
-    UnsafeProjection.create(rightKeys, right.output)
-
-  @transient private lazy val boundCondition =
-    newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
-
-  protected def hashSemiJoin(
-      streamIter: Iterator[InternalRow],
-      hashedRelation: HashedRelation,
-      numOutputRows: LongSQLMetric): Iterator[InternalRow] = {
-    val joinKeys = leftKeyGenerator
-    val joinedRow = new JoinedRow
-    streamIter.filter { current =>
-      val key = joinKeys(current)
-      lazy val rowBuffer = hashedRelation.get(key)
-      val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists {
-        (row: InternalRow) => boundCondition(joinedRow(current, row))
-      })
-      if (r) numOutputRows += 1
-      r
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/421f6c20/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
index 14389e4..fa549b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.joins
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.LeftSemi
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -33,7 +34,10 @@ case class LeftSemiJoinHash(
     rightKeys: Seq[Expression],
     left: SparkPlan,
     right: SparkPlan,
-    condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
+    condition: Option[Expression]) extends BinaryNode with HashJoin {
+
+  override val joinType = LeftSemi
+  override val buildSide = BuildRight
 
   override private[sql] lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -47,7 +51,7 @@ case class LeftSemiJoinHash(
     val numOutputRows = longMetric("numOutputRows")
 
     right.execute().zipPartitions(left.execute()) { (buildIter, streamIter) =>
-      val hashRelation = HashedRelation(buildIter.map(_.copy()), rightKeyGenerator)
+      val hashRelation = HashedRelation(buildIter.map(_.copy()), buildSideKeyGenerator)
       hashSemiJoin(streamIter, hashRelation, numOutputRows)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/421f6c20/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 580e8d8..4191991 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -49,7 +49,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
       case j: BroadcastHashJoin => j
       case j: CartesianProduct => j
       case j: BroadcastNestedLoopJoin => j
-      case j: BroadcastLeftSemiJoinHash => j
+      case j: BroadcastHashJoin => j
       case j: SortMergeJoin => j
     }
 
@@ -427,7 +427,7 @@ class JoinSuite extends QueryTest with SharedSQLContext {
     withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
       Seq(
         ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
-          classOf[BroadcastLeftSemiJoinHash])
+          classOf[BroadcastHashJoin])
       ).foreach {
         case (query, joinClass) => assertJoin(query, joinClass)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/421f6c20/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 9f33e4a..cb67264 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -38,7 +38,7 @@ import org.apache.spark.util.Benchmark
 class BenchmarkWholeStageCodegen extends SparkFunSuite {
   lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
     .set("spark.sql.shuffle.partitions", "1")
-    .set("spark.sql.autoBroadcastJoinThreshold", "0")
+    .set("spark.sql.autoBroadcastJoinThreshold", "1")
   lazy val sc = SparkContext.getOrCreate(conf)
   lazy val sqlContext = SQLContext.getOrCreate(sc)
 
@@ -200,6 +200,18 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     outer join w long codegen=false        15280 / 16497          6.9         145.7       1.0X
     outer join w long codegen=true            769 /  796        136.3           7.3      19.9X
       */
+
+    runBenchmark("semi join w long", N) {
+      sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    semi join w long:                   Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    semi join w long codegen=false           5804 / 5969         18.1          55.3       1.0X
+    semi join w long codegen=true             814 /  934        128.8           7.8       7.1X
+     */
   }
 
   ignore("sort merge join") {

http://git-wip-us.apache.org/repos/asf/spark/blob/421f6c20/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 6d5b777..babe7ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -79,7 +79,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
   }
 
   test("unsafe broadcast left semi join updates peak execution memory") {
-    testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi")
+    testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast left semi join", "leftsemi")
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/421f6c20/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index d8c9564..5eb6a74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -84,11 +84,12 @@ class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
       }
     }
 
-    test(s"$testName using BroadcastLeftSemiJoinHash") {
+    test(s"$testName using BroadcastHashJoin") {
       extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
-            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
+            BroadcastHashJoin(
+              leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right),
             expectedAnswer.map(Row.fromTuple),
             sortAnswers = true)
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/421f6c20/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 1d8c293..1468be4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -212,7 +212,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton {
     // Using `sparkPlan` because for relevant patterns in HashJoin to be
     // matched, other strategies need to be applied.
     var bhj = df.queryExecution.sparkPlan.collect {
-      case j: BroadcastLeftSemiJoinHash => j
+      case j: BroadcastHashJoin => j
     }
     assert(bhj.size === 1,
       s"actual query plans do not contain broadcast join: ${df.queryExecution}")
@@ -225,7 +225,7 @@ class StatisticsSuite extends QueryTest with TestHiveSingleton {
       sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key}=-1")
       df = sql(leftSemiJoinQuery)
       bhj = df.queryExecution.sparkPlan.collect {
-        case j: BroadcastLeftSemiJoinHash => j
+        case j: BroadcastHashJoin => j
       }
       assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")
 

