You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/03/14 13:01:36 UTC

[spark] branch master updated: [SPARK-42101][SQL][FOLLOWUP] Improve TableCacheQueryStage with CoalesceShufflePartitions

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 07f85c4c650 [SPARK-42101][SQL][FOLLOWUP] Improve TableCacheQueryStage with CoalesceShufflePartitions
07f85c4c650 is described below

commit 07f85c4c650b52b8ff2741c71d3d4aa5cfee0820
Author: ulysses-you <ul...@gmail.com>
AuthorDate: Tue Mar 14 21:01:05 2023 +0800

    [SPARK-42101][SQL][FOLLOWUP] Improve TableCacheQueryStage with CoalesceShufflePartitions
    
    ### What changes were proposed in this pull request?
    
    `CoalesceShufflePartitions` should make sure all leaves are `ExchangeQueryStageExec` to avoid collecting `TableCacheQueryStage`, because we cannot change the partition number of an `InMemoryRelation` (IMR).
    
    Add two tests to make sure `CoalesceShufflePartitions` works well with `TableCacheQueryStage`. Note that these two tests also pass without this PR: thanks to `ValidateRequirements`, the wrong plan is reverted.
    
    ### Why are the changes needed?
    
    Avoid a potential issue.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    Added tests.
    
    Closes #40406 from ulysses-you/cache-aqe-followup.
    
    Authored-by: ulysses-you <ul...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../adaptive/CoalesceShufflePartitions.scala       | 14 +++++---
 .../adaptive/AdaptiveQueryExecSuite.scala          | 37 ++++++++++++++++++++++
 2 files changed, 46 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
index 5c005efb732..dfc7e23c82d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/CoalesceShufflePartitions.scala
@@ -124,7 +124,7 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
   /**
    * Gather all coalesce-able groups such that the shuffle stages in each child of a Union operator
    * are in their independent groups if:
-   * 1) all leaf nodes of this child are shuffle stages; and
+   * 1) all leaf nodes of this child are exchange stages; and
    * 2) all these shuffle stages support coalescing.
    */
   private def collectCoalesceGroups(plan: SparkPlan): Seq[Seq[ShuffleStageInfo]] = plan match {
@@ -132,10 +132,14 @@ case class CoalesceShufflePartitions(session: SparkSession) extends AQEShuffleRe
       Seq(collectShuffleStageInfos(r))
     case unary: UnaryExecNode => collectCoalesceGroups(unary.child)
     case union: UnionExec => union.children.flatMap(collectCoalesceGroups)
-    // If not all leaf nodes are query stages, it's not safe to reduce the number of shuffle
-    // partitions, because we may break the assumption that all children of a spark plan have
-    // same number of output partitions.
-    case p if p.collectLeaves().forall(_.isInstanceOf[QueryStageExec]) =>
+    // If not all leaf nodes are exchange query stages, it's not safe to reduce the number of
+    // shuffle partitions, because we may break the assumption that all children of a spark plan
+    // have same number of output partitions.
+    // Note that, `BroadcastQueryStageExec` is a valid case:
+    // If a join has been optimized from shuffled join to broadcast join, then the one side is
+    // `BroadcastQueryStageExec` and other side is `ShuffleQueryStageExec`. It can coalesce the
+    // shuffle side as we do not expect broadcast exchange has same partition number.
+    case p if p.collectLeaves().forall(_.isInstanceOf[ExchangeQueryStageExec]) =>
       val shuffleStages = collectShuffleStageInfos(p)
       // ShuffleExchanges introduced by repartition do not support partition number change.
       // We change the number of partitions only if all the ShuffleExchanges support it.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index d2fe588c9a5..8ed31e1968c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -2765,6 +2765,43 @@ class AdaptiveQueryExecSuite
       checkShuffleAndSort(firstAccess = false)
     }
   }
+
+  test("SPARK-42101: Do not coalesce shuffle partition if other side is TableCacheQueryStage") {
+    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3",
+        SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
+        SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true",
+        SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") {
+      withTempView("v1", "v2") {
+        Seq(1, 2).toDF("c1").repartition(3, $"c1").cache().createOrReplaceTempView("v1")
+        Seq(1, 2).toDF("c2").createOrReplaceTempView("v2")
+
+        val df = spark.sql("SELECT * FROM v1 JOIN v2 ON v1.c1 = v2.c2")
+        df.collect()
+        val finalPlan = df.queryExecution.executedPlan
+        assert(collect(finalPlan) {
+          case q: ShuffleQueryStageExec => q
+        }.size == 1)
+        assert(collect(finalPlan) {
+          case r: AQEShuffleReadExec => r
+        }.isEmpty)
+      }
+    }
+  }
+
+  test("SPARK-42101: Coalesce shuffle partition with union even if exists TableCacheQueryStage") {
+    withSQLConf(SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true",
+        SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM.key -> "1") {
+      val cached = Seq(1).toDF("c").cache()
+      val df = Seq(2).toDF("c").repartition($"c").unionAll(cached)
+      df.collect()
+      assert(collect(df.queryExecution.executedPlan) {
+        case r @ AQEShuffleReadExec(_: ShuffleQueryStageExec, _) => r
+      }.size == 1)
+      assert(collect(df.queryExecution.executedPlan) {
+        case c: TableCacheQueryStageExec => c
+      }.size == 1)
+    }
+  }
 }
 
 /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org