Posted to commits@spark.apache.org by do...@apache.org on 2020/11/20 18:05:20 UTC

[spark] branch branch-2.4 updated: [SPARK-33472][SQL][2.4] Adjust RemoveRedundantSorts rule order

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 6705912  [SPARK-33472][SQL][2.4] Adjust RemoveRedundantSorts rule order
6705912 is described below

commit 67059127a8e8218dc2a8abae6095c1aa299ce175
Author: allisonwang-db <66...@users.noreply.github.com>
AuthorDate: Fri Nov 20 10:02:26 2020 -0800

    [SPARK-33472][SQL][2.4] Adjust RemoveRedundantSorts rule order
    
    Backport #30373 for branch-2.4.
    
    ### What changes were proposed in this pull request?
    This PR switches the order of the rules `RemoveRedundantSorts` and `EnsureRequirements` so that `EnsureRequirements` is invoked before `RemoveRedundantSorts`, avoiding an `IllegalArgumentException` when instantiating `PartitioningCollection`.
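
    For context, the `IllegalArgumentException` comes from a `require` in
    `PartitioningCollection`'s constructor. A minimal sketch of that check
    (mirroring `org.apache.spark.sql.catalyst.plans.physical.PartitioningCollection`;
    the class name here is hypothetical, to keep the sketch self-contained):

    ```scala
    import org.apache.spark.sql.catalyst.plans.physical.Partitioning

    // Sketch: the constructor itself requires every member partitioning to report
    // the same numPartitions, so merely constructing a joined plan's output
    // partitioning fails while the two join sides still disagree on partition count.
    case class PartitioningCollectionSketch(partitionings: Seq[Partitioning]) {
      require(
        partitionings.map(_.numPartitions).distinct.length == 1,
        "PartitioningCollection requires all of its partitionings have the same numPartitions.")
      def numPartitions: Int = partitionings.head.numPartitions
    }
    ```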
    
    ### Why are the changes needed?
    The `RemoveRedundantSorts` rule uses a SparkPlan's `outputPartitioning` to check whether a sort node is redundant. Currently, it runs before `EnsureRequirements`. Since `PartitioningCollection` requires the left and right partitionings to have the same number of partitions, which is not necessarily true before `EnsureRequirements` has inserted the required shuffles, the rule can fail with the following exception:
    ```
    IllegalArgumentException: requirement failed: PartitioningCollection requires all of its partitionings have the same numPartitions.
    ```
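
    To show where `outputPartitioning` gets consulted, here is a condensed sketch of
    the rule's redundancy check (the shape follows the rule on master; the actual
    rule is also gated behind a config flag, omitted here):

    ```scala
    import org.apache.spark.sql.catalyst.expressions.SortOrder
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.execution.{SortExec, SparkPlan}

    // Sketch only: a SortExec is redundant when its child already satisfies both
    // the requested ordering and the required distribution. The distribution check
    // reads child.outputPartitioning, which is what can throw the exception above
    // when a PartitioningCollection is built before EnsureRequirements has run.
    object RemoveRedundantSortsSketch extends Rule[SparkPlan] {
      def apply(plan: SparkPlan): SparkPlan = plan transform {
        case sort @ SortExec(orders, _, child, _)
            if SortOrder.orderingSatisfies(child.outputOrdering, orders) &&
              child.outputPartitioning.satisfies(sort.requiredChildDistribution.head) =>
          child
      }
    }
    ```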
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Unit test in `RemoveRedundantSortsSuite`.
    
    Closes #30437 from allisonwang-db/spark-33472-2.4.
    
    Authored-by: allisonwang-db <66...@users.noreply.github.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../spark/sql/execution/QueryExecution.scala       |  4 +++-
 .../org/apache/spark/sql/execution/SparkPlan.scala |  7 +++++-
 .../sql/execution/RemoveRedundantSortsSuite.scala  | 27 ++++++++++++++++++++++
 3 files changed, 36 insertions(+), 2 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index b92f346..0b9c469 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -97,8 +97,10 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
   /** A sequence of rules that will be applied in order to the physical plan before execution. */
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
     PlanSubqueries(sparkSession),
-    RemoveRedundantSorts(sparkSession.sessionState.conf),
     EnsureRequirements(sparkSession.sessionState.conf),
+    // `RemoveRedundantSorts` needs to be added after `EnsureRequirements` to guarantee the same
+    // number of partitions when instantiating PartitioningCollection.
+    RemoveRedundantSorts(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
     ReuseExchange(sparkSession.sessionState.conf),
     ReuseSubquery(sparkSession.sessionState.conf))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 7646f96..28addf6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -91,7 +91,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   def longMetric(name: String): SQLMetric = metrics(name)
 
   // TODO: Move to `DistributedPlan`
-  /** Specifies how data is partitioned across different nodes in the cluster. */
+  /**
+   * Specifies how data is partitioned across different nodes in the cluster.
+   * Note this method may fail if it is invoked before `EnsureRequirements` is applied
+   * since `PartitioningCollection` requires all its partitionings to have
+   * the same number of partitions.
+   */
   def outputPartitioning: Partitioning = UnknownPartitioning(0) // TODO: WRONG WIDTH!
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
index f7987e2..b82e5cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.catalyst.plans.physical.{RangePartitioning, UnknownPartitioning}
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -99,4 +101,29 @@ class RemoveRedundantSortsSuite
       }
     }
   }
+
+  test("SPARK-33472: shuffled join with different left and right side partition numbers") {
+    withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+      withTempView("t1", "t2") {
+        spark.range(0, 100, 1, 2).select('id as "key").createOrReplaceTempView("t1")
+        (0 to 100).toDF("key").createOrReplaceTempView("t2")
+
+        val query = """
+          |SELECT t1.key
+          |FROM t1 JOIN t2 ON t1.key = t2.key
+          |WHERE t1.key > 10 AND t2.key < 50
+          |ORDER BY t1.key ASC
+        """.stripMargin
+
+        val df = sql(query)
+        val sparkPlan = df.queryExecution.sparkPlan
+        val join = sparkPlan.collect { case j: SortMergeJoinExec => j }.head
+        val leftPartitioning = join.left.outputPartitioning
+        assert(leftPartitioning.isInstanceOf[RangePartitioning])
+        assert(leftPartitioning.numPartitions == 2)
+        assert(join.right.outputPartitioning == UnknownPartitioning(0))
+        checkSorts(query, 3, 3)
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org