You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/11/30 23:36:31 UTC

spark git commit: [SPARK-22489][SQL] Shouldn't change broadcast join buildSide if user clearly specified

Repository: spark
Updated Branches:
  refs/heads/master 6ac57fd0d -> bcceab649


[SPARK-22489][SQL] Shouldn't change broadcast join buildSide if user clearly specified

## What changes were proposed in this pull request?

How to reproduce:
```scala
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec

spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("table1")
spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value").createTempView("table2")

val bl = sql("SELECT /*+ MAPJOIN(t1) */ * FROM table1 t1 JOIN table2 t2 ON t1.key = t2.key").queryExecution.executedPlan

println(bl.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide)
```
The result is `BuildRight`, but it should be `BuildLeft`. This PR fixes this issue.
## How was this patch tested?

Unit tests (new test case in `BroadcastJoinSuite`).

Author: Yuming Wang <wg...@gmail.com>

Closes #19714 from wangyum/SPARK-22489.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bcceab64
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bcceab64
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bcceab64

Branch: refs/heads/master
Commit: bcceab649510a45f4c4b8e44b157c9987adff6f4
Parents: 6ac57fd
Author: Yuming Wang <wg...@gmail.com>
Authored: Thu Nov 30 15:36:26 2017 -0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Thu Nov 30 15:36:26 2017 -0800

----------------------------------------------------------------------
 docs/sql-programming-guide.md                   | 58 ++++++++++++++++
 .../spark/sql/execution/SparkStrategies.scala   | 67 ++++++++++++++-----
 .../execution/joins/BroadcastJoinSuite.scala    | 69 +++++++++++++++++++-
 3 files changed, 177 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bcceab64/docs/sql-programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 983770d..a1b9c3b 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1492,6 +1492,64 @@ that these options will be deprecated in future release as more optimizations ar
   </tr>
 </table>
 
+## Broadcast Hint for SQL Queries
+
+The `BROADCAST` hint guides Spark to broadcast each specified table when joining them with another table or view.
+When Spark is deciding the join method, the broadcast hash join (i.e., BHJ) is preferred,
+even if the statistics are above the configuration `spark.sql.autoBroadcastJoinThreshold`.
+When both sides of a join are specified, Spark broadcasts the one with the smaller estimated size.
+Note that Spark does not guarantee that BHJ is always chosen, since not all cases (e.g., full outer join)
+support BHJ. When the broadcast nested loop join is selected, we still respect the hint.
+
+<div class="codetabs">
+
+<div data-lang="scala"  markdown="1">
+
+{% highlight scala %}
+import org.apache.spark.sql.functions.broadcast
+broadcast(spark.table("src")).join(spark.table("records"), "key").show()
+{% endhighlight %}
+
+</div>
+
+<div data-lang="java"  markdown="1">
+
+{% highlight java %}
+import static org.apache.spark.sql.functions.broadcast;
+broadcast(spark.table("src")).join(spark.table("records"), "key").show();
+{% endhighlight %}
+
+</div>
+
+<div data-lang="python"  markdown="1">
+
+{% highlight python %}
+from pyspark.sql.functions import broadcast
+broadcast(spark.table("src")).join(spark.table("records"), "key").show()
+{% endhighlight %}
+
+</div>
+
+<div data-lang="r"  markdown="1">
+
+{% highlight r %}
+src <- sql("SELECT * FROM src")
+records <- sql("SELECT * FROM records")
+head(join(broadcast(src), records, src$key == records$key))
+{% endhighlight %}
+
+</div>
+
+<div data-lang="sql"  markdown="1">
+
+{% highlight sql %}
+-- We accept BROADCAST, BROADCASTJOIN and MAPJOIN for broadcast hint
+SELECT /*+ BROADCAST(r) */ * FROM records r JOIN src s ON r.key = s.key
+{% endhighlight %}
+
+</div>
+</div>
+
 # Distributed SQL Engine
 
 Spark SQL can also act as a distributed query engine using its JDBC/ODBC or command-line interface.

http://git-wip-us.apache.org/repos/asf/spark/blob/bcceab64/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 19b858f..1fe3cb1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
 import org.apache.spark.sql.execution.command._
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
-import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamingQuery
@@ -91,12 +91,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    * predicates can be evaluated by matching join keys. If found,  Join implementations are chosen
    * with the following precedence:
    *
-   * - Broadcast: if one side of the join has an estimated physical size that is smaller than the
-   *     user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold
-   *     or if that side has an explicit broadcast hint (e.g. the user applied the
-   *     [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame), then that side
-   *     of the join will be broadcasted and the other side will be streamed, with no shuffling
-   *     performed. If both sides of the join are eligible to be broadcasted then the
+   * - Broadcast: We prefer to broadcast the join side with an explicit broadcast hint(e.g. the
+   *     user applied the [[org.apache.spark.sql.functions.broadcast()]] function to a DataFrame).
+   *     If both sides have the broadcast hint, we prefer to broadcast the side with a smaller
+   *     estimated physical size. If neither one of the sides has the broadcast hint,
+   *     we only broadcast the join side if its estimated physical size that is smaller than
+   *     the user-configurable [[SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]] threshold.
    * - Shuffle hash join: if the average size of a single partition is small enough to build a hash
    *     table.
    * - Sort merge: if the matching join keys are sortable.
@@ -112,9 +112,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
      * Matches a plan whose output should be small enough to be used in broadcast join.
      */
     private def canBroadcast(plan: LogicalPlan): Boolean = {
-      plan.stats.hints.broadcast ||
-        (plan.stats.sizeInBytes >= 0 &&
-          plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold)
+      plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.autoBroadcastJoinThreshold
     }
 
     /**
@@ -149,11 +147,46 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case _ => false
     }
 
+    private def broadcastSide(
+        canBuildLeft: Boolean,
+        canBuildRight: Boolean,
+        left: LogicalPlan,
+        right: LogicalPlan): BuildSide = {
+
+      def smallerSide =
+        if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft
+
+      val buildRight = canBuildRight && right.stats.hints.broadcast
+      val buildLeft = canBuildLeft && left.stats.hints.broadcast
+
+      if (buildRight && buildLeft) {
+        // Broadcast smaller side base on its estimated physical size
+        // if both sides have broadcast hint
+        smallerSide
+      } else if (buildRight) {
+        BuildRight
+      } else if (buildLeft) {
+        BuildLeft
+      } else if (canBuildRight && canBuildLeft) {
+        // for the last default broadcast nested loop join
+        smallerSide
+      } else {
+        throw new AnalysisException("Can not decide which side to broadcast for this join")
+      }
+    }
+
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
 
       // --- BroadcastHashJoin --------------------------------------------------------------------
 
       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
+        if (canBuildRight(joinType) && right.stats.hints.broadcast) ||
+          (canBuildLeft(joinType) && left.stats.hints.broadcast) =>
+        val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right)
+        Seq(joins.BroadcastHashJoinExec(
+          leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
+
+      case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right)
         if canBuildRight(joinType) && canBroadcast(right) =>
         Seq(joins.BroadcastHashJoinExec(
           leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right)))
@@ -190,6 +223,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
       // Pick BroadcastNestedLoopJoin if one side could be broadcasted
       case j @ logical.Join(left, right, joinType, condition)
+        if (canBuildRight(joinType) && right.stats.hints.broadcast) ||
+          (canBuildLeft(joinType) && left.stats.hints.broadcast) =>
+        val buildSide = broadcastSide(canBuildLeft(joinType), canBuildRight(joinType), left, right)
+        joins.BroadcastNestedLoopJoinExec(
+          planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
+
+      case j @ logical.Join(left, right, joinType, condition)
           if canBuildRight(joinType) && canBroadcast(right) =>
         joins.BroadcastNestedLoopJoinExec(
           planLater(left), planLater(right), BuildRight, joinType, condition) :: Nil
@@ -203,12 +243,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         joins.CartesianProductExec(planLater(left), planLater(right), condition) :: Nil
 
       case logical.Join(left, right, joinType, condition) =>
-        val buildSide =
-          if (right.stats.sizeInBytes <= left.stats.sizeInBytes) {
-            BuildRight
-          } else {
-            BuildLeft
-          }
+        val buildSide = broadcastSide(canBuildLeft = true, canBuildRight = true, left, right)
         // This join could be very slow or OOM
         joins.BroadcastNestedLoopJoinExec(
           planLater(left), planLater(right), buildSide, joinType, condition) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/bcceab64/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index a0fad86..67e2cdc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,7 +22,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.AccumulatorSuite
 import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -223,4 +223,71 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
     assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil)
     assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil)
   }
+
+  test("Shouldn't change broadcast join buildSide if user clearly specified") {
+    def assertJoinBuildSide(sqlStr: String, joinMethod: String, buildSide: BuildSide): Any = {
+      val executedPlan = sql(sqlStr).queryExecution.executedPlan
+      executedPlan match {
+        case b: BroadcastNestedLoopJoinExec =>
+          assert(b.getClass.getSimpleName === joinMethod)
+          assert(b.buildSide === buildSide)
+        case w: WholeStageCodegenExec =>
+          assert(w.children.head.getClass.getSimpleName === joinMethod)
+          assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide)
+      }
+    }
+
+    withTempView("t1", "t2") {
+      spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value").createTempView("t1")
+      spark.createDataFrame(Seq((1, "1"), (2, "12.3"), (2, "123"))).toDF("key", "value")
+        .createTempView("t2")
+
+      val t1Size = spark.table("t1").queryExecution.analyzed.children.head.stats.sizeInBytes
+      val t2Size = spark.table("t2").queryExecution.analyzed.children.head.stats.sizeInBytes
+      assert(t1Size < t2Size)
+
+      val bh = BroadcastHashJoinExec.toString
+      val bl = BroadcastNestedLoopJoinExec.toString
+
+      // INNER JOIN && t1Size < t2Size => BuildLeft
+      assertJoinBuildSide(
+        "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft)
+      // LEFT JOIN => BuildRight
+      assertJoinBuildSide(
+        "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 LEFT JOIN t2 ON t1.key = t2.key", bh, BuildRight)
+      // RIGHT JOIN => BuildLeft
+      assertJoinBuildSide(
+        "SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 RIGHT JOIN t2 ON t1.key = t2.key", bh, BuildLeft)
+      // INNER JOIN && broadcast(t1) => BuildLeft
+      assertJoinBuildSide(
+        "SELECT /*+ MAPJOIN(t1) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildLeft)
+      // INNER JOIN && broadcast(t2) => BuildRight
+      assertJoinBuildSide(
+        "SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2 ON t1.key = t2.key", bh, BuildRight)
+
+
+      withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0",
+        SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+        // INNER JOIN && t1Size < t2Size => BuildLeft
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 JOIN t2", bl, BuildLeft)
+        // FULL JOIN && t1Size < t2Size => BuildLeft
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 FULL JOIN t2", bl, BuildLeft)
+        // LEFT JOIN => BuildRight
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 LEFT JOIN t2", bl, BuildRight)
+        // RIGHT JOIN => BuildLeft
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t1, t2) */ * FROM t1 RIGHT JOIN t2", bl, BuildLeft)
+        // INNER JOIN && broadcast(t1) => BuildLeft
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t1) */ * FROM t1 JOIN t2", bl, BuildLeft)
+        // INNER JOIN && broadcast(t2) => BuildRight
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t2) */ * FROM t1 JOIN t2", bl, BuildRight)
+        // FULL OUTER && broadcast(t1) => BuildLeft
+        assertJoinBuildSide("SELECT /*+ MAPJOIN(t1) */ * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft)
+        // FULL OUTER && broadcast(t2) => BuildRight
+        assertJoinBuildSide(
+          "SELECT /*+ MAPJOIN(t2) */ * FROM t1 FULL OUTER JOIN t2", bl, BuildRight)
+        // FULL OUTER && t1Size < t2Size => BuildLeft
+        assertJoinBuildSide("SELECT * FROM t1 FULL OUTER JOIN t2", bl, BuildLeft)
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org