You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2023/06/08 03:15:19 UTC

[spark] branch master updated: [SPARK-44000][SQL] Add hint to disable broadcasting and replicating one side of join

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d88633ada5e [SPARK-44000][SQL] Add hint to disable broadcasting and replicating one side of join
d88633ada5e is described below

commit d88633ada5eb73e8876acaa2c2a53b9596f2acdd
Author: aokolnychyi <ao...@apple.com>
AuthorDate: Wed Jun 7 20:15:05 2023 -0700

    [SPARK-44000][SQL] Add hint to disable broadcasting and replicating one side of join
    
    ### What changes were proposed in this pull request?
    
    This PR adds a new internal join hint to disable broadcasting and replicating one side of a join.
    
    ### Why are the changes needed?
    
    These changes are needed to disable broadcasting and replicating one side of a join when doing so is not permitted, such as in the cardinality check in MERGE operations in PR #41448.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    This PR comes with tests. More tests are in #41448.
    
    Closes #41499 from aokolnychyi/spark-44000.
    
    Authored-by: aokolnychyi <ao...@apple.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../spark/sql/catalyst/optimizer/joins.scala       | 34 +++++++++++-
 .../spark/sql/catalyst/plans/logical/hints.scala   | 10 ++++
 .../spark/sql/execution/SparkStrategies.scala      | 29 +++++++---
 .../scala/org/apache/spark/sql/JoinSuite.scala     | 64 +++++++++++++++++++++-
 4 files changed, 127 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 48b4007a897..8f03b93dce7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -341,6 +341,16 @@ trait JoinSelectionHelper {
     )
   }
 
+  def getBroadcastNestedLoopJoinBuildSide(hint: JoinHint): Option[BuildSide] = {
+    if (hintToNotBroadcastAndReplicateLeft(hint)) {
+      Some(BuildRight)
+    } else if (hintToNotBroadcastAndReplicateRight(hint)) {
+      Some(BuildLeft)
+    } else {
+      None
+    }
+  }
+
   def getSmallerSide(left: LogicalPlan, right: LogicalPlan): BuildSide = {
     if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft
   }
@@ -413,11 +423,19 @@ trait JoinSelectionHelper {
   }
 
   def hintToNotBroadcastLeft(hint: JoinHint): Boolean = {
-    hint.leftHint.exists(_.strategy.contains(NO_BROADCAST_HASH))
+    hint.leftHint.flatMap(_.strategy).exists {
+      case NO_BROADCAST_HASH => true
+      case NO_BROADCAST_AND_REPLICATION => true
+      case _ => false
+    }
   }
 
   def hintToNotBroadcastRight(hint: JoinHint): Boolean = {
-    hint.rightHint.exists(_.strategy.contains(NO_BROADCAST_HASH))
+    hint.rightHint.flatMap(_.strategy).exists {
+      case NO_BROADCAST_HASH => true
+      case NO_BROADCAST_AND_REPLICATION => true
+      case _ => false
+    }
   }
 
   def hintToShuffleHashJoinLeft(hint: JoinHint): Boolean = {
@@ -454,6 +472,18 @@ trait JoinSelectionHelper {
       hint.rightHint.exists(_.strategy.contains(SHUFFLE_REPLICATE_NL))
   }
 
+  def hintToNotBroadcastAndReplicate(hint: JoinHint): Boolean = {
+    hintToNotBroadcastAndReplicateLeft(hint) || hintToNotBroadcastAndReplicateRight(hint)
+  }
+
+  def hintToNotBroadcastAndReplicateLeft(hint: JoinHint): Boolean = {
+    hint.leftHint.exists(_.strategy.contains(NO_BROADCAST_AND_REPLICATION))
+  }
+
+  def hintToNotBroadcastAndReplicateRight(hint: JoinHint): Boolean = {
+    hint.rightHint.exists(_.strategy.contains(NO_BROADCAST_AND_REPLICATION))
+  }
+
   private def getBuildSide(
       canBuildLeft: Boolean,
       canBuildRight: Boolean,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index 5dc3eb707f6..b17bab7849b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -187,6 +187,16 @@ case object PREFER_SHUFFLE_HASH extends JoinStrategyHint {
   override def hintAliases: Set[String] = Set.empty
 }
 
+/**
+ * An internal hint to prohibit broadcasting and replicating one side of a join. This hint is used
+ * by some rules where broadcasting or replicating a particular side of the join is not permitted,
+ * such as the cardinality check in MERGE operations.
+ */
+case object NO_BROADCAST_AND_REPLICATION extends JoinStrategyHint {
+  override def displayName: String = "no_broadcast_and_replication"
+  override def hintAliases: Set[String] = Set.empty
+}
+
 /**
  * The callback for implementing customized strategies of handling hint errors.
  */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 9f256156a82..f29a73ce936 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -272,7 +272,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }
 
         def createCartesianProduct() = {
-          if (joinType.isInstanceOf[InnerLike]) {
+          if (joinType.isInstanceOf[InnerLike] && !hintToNotBroadcastAndReplicate(hint)) {
             // `CartesianProductExec` can't implicitly evaluate equal join condition, here we should
             // pass the original condition which includes both equal and non-equal conditions.
             Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), j.condition)))
@@ -288,7 +288,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             .orElse(createCartesianProduct())
             .getOrElse {
               // This join could be very slow or OOM
-              val buildSide = getSmallerSide(left, right)
+              val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint)
+              val buildSide = requiredBuildSide.getOrElse(getSmallerSide(left, right))
               Seq(joins.BroadcastNestedLoopJoinExec(
                 planLater(left), planLater(right), buildSide, joinType, j.condition))
             }
@@ -336,7 +337,19 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           if (canBuildBroadcastLeft(joinType)) BuildLeft else BuildRight
         }
 
-        def createBroadcastNLJoin(buildLeft: Boolean, buildRight: Boolean) = {
+        def createBroadcastNLJoin(onlyLookingAtHint: Boolean) = {
+          val buildLeft = if (onlyLookingAtHint) {
+            hintToBroadcastLeft(hint)
+          } else {
+            canBroadcastBySize(left, conf) && !hintToNotBroadcastAndReplicateLeft(hint)
+          }
+
+          val buildRight = if (onlyLookingAtHint) {
+            hintToBroadcastRight(hint)
+          } else {
+            canBroadcastBySize(right, conf) && !hintToNotBroadcastAndReplicateRight(hint)
+          }
+
           val maybeBuildSide = if (buildLeft && buildRight) {
             Some(desiredBuildSide)
           } else if (buildLeft) {
@@ -354,7 +367,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }
 
         def createCartesianProduct() = {
-          if (joinType.isInstanceOf[InnerLike]) {
+          if (joinType.isInstanceOf[InnerLike] && !hintToNotBroadcastAndReplicate(hint)) {
             Some(Seq(joins.CartesianProductExec(planLater(left), planLater(right), condition)))
           } else {
             None
@@ -362,19 +375,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         }
 
         def createJoinWithoutHint() = {
-          createBroadcastNLJoin(canBroadcastBySize(left, conf), canBroadcastBySize(right, conf))
+          createBroadcastNLJoin(false)
             .orElse(createCartesianProduct())
             .getOrElse {
               // This join could be very slow or OOM
+              val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint)
+              val buildSide = requiredBuildSide.getOrElse(desiredBuildSide)
               Seq(joins.BroadcastNestedLoopJoinExec(
-                planLater(left), planLater(right), desiredBuildSide, joinType, condition))
+                planLater(left), planLater(right), buildSide, joinType, condition))
             }
         }
 
         if (hint.isEmpty) {
           createJoinWithoutHint()
         } else {
-          createBroadcastNLJoin(hintToBroadcastLeft(hint), hintToBroadcastRight(hint))
+          createBroadcastNLJoin(true)
             .orElse { if (hintToShuffleReplicateNL(hint)) createCartesianProduct() else None }
             .getOrElse(createJoinWithoutHint())
         }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 12e3f2eb202..60689f96700 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -28,7 +28,8 @@ import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder}
-import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, HintInfo, Join, JoinHint, NO_BROADCAST_AND_REPLICATION}
 import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -92,6 +93,67 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
     operators.head
   }
 
+  test("NO_BROADCAST_AND_REPLICATION hint is respected in cross joins") {
+    withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
+      val noBroadcastAndReplicationHint = JoinHint(
+        leftHint = None,
+        rightHint = Some(HintInfo(Some(NO_BROADCAST_AND_REPLICATION))))
+
+      val join = testData.crossJoin(testData2).queryExecution.optimizedPlan.asInstanceOf[Join]
+      val joinWithHint = join.copy(hint = noBroadcastAndReplicationHint)
+
+      val planned = spark.sessionState.planner.JoinSelection(join)
+      assert(planned.size == 1)
+      assert(planned.head.isInstanceOf[CartesianProductExec])
+
+      val plannedWithHint = spark.sessionState.planner.JoinSelection(joinWithHint)
+      assert(plannedWithHint.size == 1)
+      assert(plannedWithHint.head.isInstanceOf[BroadcastNestedLoopJoinExec])
+      assert(plannedWithHint.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide == BuildLeft)
+    }
+  }
+
+  test("NO_BROADCAST_AND_REPLICATION hint disables broadcast hash joins") {
+    sql("CACHE TABLE testData")
+    sql("CACHE TABLE testData2")
+
+    val noBroadcastAndReplicationHint = JoinHint(
+      leftHint = Some(HintInfo(Some(NO_BROADCAST_AND_REPLICATION))),
+      rightHint = Some(HintInfo(Some(NO_BROADCAST_AND_REPLICATION))))
+
+    val ds = sql("SELECT * FROM testData JOIN testData2 ON key = a")
+    val join = ds.queryExecution.optimizedPlan.asInstanceOf[Join]
+    val joinWithHint = join.copy(hint = noBroadcastAndReplicationHint)
+
+    val planned = spark.sessionState.planner.JoinSelection(join)
+    assert(planned.size == 1)
+    assert(planned.head.isInstanceOf[BroadcastHashJoinExec])
+
+    val plannedWithHint = spark.sessionState.planner.JoinSelection(joinWithHint)
+    assert(plannedWithHint.size == 1)
+    assert(plannedWithHint.head.isInstanceOf[SortMergeJoinExec])
+  }
+
+  test("NO_BROADCAST_AND_REPLICATION controls build side in BNLJ") {
+    val noBroadcastAndReplicationHint = JoinHint(
+      leftHint = None,
+      rightHint = Some(HintInfo(Some(NO_BROADCAST_AND_REPLICATION))))
+
+    val ds = testData.join(testData2, $"key" === 1, "left_outer")
+    val join = ds.queryExecution.optimizedPlan.asInstanceOf[Join]
+    val joinWithHint = join.copy(hint = noBroadcastAndReplicationHint)
+
+    val planned = spark.sessionState.planner.JoinSelection(join)
+    assert(planned.size == 1)
+    assert(planned.head.isInstanceOf[BroadcastNestedLoopJoinExec])
+    assert(planned.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide == BuildRight)
+
+    val plannedWithHint = spark.sessionState.planner.JoinSelection(joinWithHint)
+    assert(plannedWithHint.size == 1)
+    assert(plannedWithHint.head.isInstanceOf[BroadcastNestedLoopJoinExec])
+    assert(plannedWithHint.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide == BuildLeft)
+  }
+
   test("join operator selection") {
     spark.sharedState.cacheManager.clearCache()
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org