You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yu...@apache.org on 2023/09/04 12:23:56 UTC

[spark] branch master updated: [SPARK-44846][SQL] Convert the lower redundant Aggregate to Project in RemoveRedundantAggregates

This is an automated email from the ASF dual-hosted git repository.

yumwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 32a87f03da7 [SPARK-44846][SQL] Convert the lower redundant Aggregate to Project in RemoveRedundantAggregates
32a87f03da7 is described below

commit 32a87f03da7eef41161a5a7a3aba4a48e0421912
Author: zml1206 <zh...@gmail.com>
AuthorDate: Mon Sep 4 20:23:39 2023 +0800

    [SPARK-44846][SQL] Convert the lower redundant Aggregate to Project in RemoveRedundantAggregates
    
    ### What changes were proposed in this pull request?
    This PR provides a safe way to remove a redundant `Aggregate` in the rule `RemoveRedundantAggregates`: instead of merging the two aggregates, it simply converts the lower redundant `Aggregate` into a `Project`.
    
    ### Why are the changes needed?
    After `RemoveRedundantAggregates` merges two aggregates, the resulting aggregate can contain complex grouping expressions. If `aggregateExpressions` contains branches (`if` / `case`), the `PushFoldableIntoBranches` rule may rewrite them so that `groupingExpressions` is no longer a subexpression of `aggregateExpressions`. This then causes a `bindReference` error.
    For example:
    ```
    SELECT c * 2 AS d
    FROM (
             SELECT if(b > 1, 1, b) AS c
             FROM (
                      SELECT if(a < 0, 0, a) AS b
                      FROM VALUES (-1), (1), (2) AS t1(a)
                  ) t2
             GROUP BY b
         ) t3
    GROUP BY c
    ```
    Before this PR:
    ```
    == Optimized Logical Plan ==
    Aggregate [if ((b#0 > 1)) 1 else b#0], [if ((b#0 > 1)) 2 else (b#0 * 2) AS d#2]
    +- Project [if ((a#3 < 0)) 0 else a#3 AS b#0]
       +- LocalRelation [a#3]
    ```
    ```
    == Error ==
    Couldn't find b#0 in [if ((b#0 > 1)) 1 else b#0#7]
    java.lang.IllegalStateException: Couldn't find b#0 in [if ((b#0 > 1)) 1 else b#0#7]
            at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80)
            at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73)
            at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:461)
            at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:76)
            at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:461)
            at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:466)
            at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren(TreeNode.scala:1241)
            at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren$(TreeNode.scala:1240)
            at org.apache.spark.sql.catalyst.expressions.BinaryExpression.mapChildren(Expression.scala:653)
            ......
    ```
    After this PR:
    ```
    == Optimized Logical Plan ==
    Aggregate [c#1], [(c#1 * 2) AS d#2]
    +- Project [if ((b#0 > 1)) 1 else b#0 AS c#1]
       +- Project [if ((a#3 < 0)) 0 else a#3 AS b#0]
          +- LocalRelation [a#3]
    ```
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
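    Updated unit tests in `RemoveRedundantAggregatesSuite` and added a SQL test case to `group-by.sql` (with its analyzer and result golden files).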
    
    Closes #42633 from zml1206/SPARK-44846-2.
    
    Authored-by: zml1206 <zh...@gmail.com>
    Signed-off-by: Yuming Wang <yu...@ebay.com>
---
 .../optimizer/RemoveRedundantAggregates.scala       | 19 ++-----------------
 .../optimizer/RemoveRedundantAggregatesSuite.scala  | 21 ++++++++++++---------
 .../sql-tests/analyzer-results/group-by.sql.out     | 21 +++++++++++++++++++++
 .../test/resources/sql-tests/inputs/group-by.sql    | 13 +++++++++++++
 .../resources/sql-tests/results/group-by.sql.out    | 18 ++++++++++++++++++
 5 files changed, 66 insertions(+), 26 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
index 93f3557a8c8..badf4065f5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.analysis.PullOutNondeterministic
 import org.apache.spark.sql.catalyst.expressions.{AliasHelper, AttributeSet, ExpressionSet}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
@@ -32,22 +31,8 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
     _.containsPattern(AGGREGATE), ruleId) {
     case upper @ Aggregate(_, _, lower: Aggregate) if isLowerRedundant(upper, lower) =>
-      val aliasMap = getAliasMap(lower)
-
-      val newAggregate = upper.copy(
-        child = lower.child,
-        groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)),
-        aggregateExpressions = upper.aggregateExpressions.map(
-          replaceAliasButKeepName(_, aliasMap))
-      )
-
-      // We might have introduces non-deterministic grouping expression
-      if (newAggregate.groupingExpressions.exists(!_.deterministic)) {
-        PullOutNondeterministic.applyLocally.applyOrElse(newAggregate, identity[LogicalPlan])
-      } else {
-        newAggregate
-      }
-
+      val projectList = lower.aggregateExpressions.filter(upper.references.contains(_))
+      upper.copy(child = Project(projectList, lower.child))
     case agg @ Aggregate(groupingExps, _, child)
         if agg.groupOnly && child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
       Project(agg.aggregateExpressions, child)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
index 91201b67501..2af3057c0b8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
@@ -30,7 +30,8 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = Batch("RemoveRedundantAggregates", FixedPoint(10),
-      RemoveRedundantAggregates) :: Nil
+      RemoveRedundantAggregates,
+      RemoveNoopOperators) :: Nil
   }
 
   private val relation = LocalRelation($"a".int, $"b".int)
@@ -52,6 +53,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
         .groupBy($"a")($"a")
         .analyze
       val expected = relation
+        .select($"a")
         .groupBy($"a")($"a")
         .analyze
       val optimized = Optimize.execute(query)
@@ -67,6 +69,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
         .groupBy($"a")($"a")
         .analyze
       val expected = relation
+        .select($"a")
         .groupBy($"a")($"a")
         .analyze
       val optimized = Optimize.execute(query)
@@ -80,6 +83,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
       .groupBy($"a")($"a")
       .analyze
     val expected = relation
+      .select($"a")
       .groupBy($"a")($"a")
       .analyze
     val optimized = Optimize.execute(query)
@@ -93,7 +97,8 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
         .groupBy($"c")($"c")
         .analyze
       val expected = relation
-        .groupBy($"a" + $"b")(($"a" + $"b") as "c")
+        .select(($"a" + $"b") as "c")
+        .groupBy($"c")($"c")
         .analyze
       val optimized = Optimize.execute(query)
       comparePlans(optimized, expected)
@@ -106,6 +111,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
       .groupBy($"a")($"a", rand(0) as "c")
       .analyze
     val expected = relation
+      .select($"a")
       .groupBy($"a")($"a", rand(0) as "c")
       .analyze
     val optimized = Optimize.execute(query)
@@ -118,7 +124,9 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
       .groupBy($"a", $"c")($"a", $"c")
       .analyze
     val expected = relation
-      .groupBy($"a", $"c")($"a", rand(0) as "c")
+      .select($"a", $"b", rand(0) as "_nondeterministic")
+      .select($"a", $"_nondeterministic" as "c")
+      .groupBy($"a", $"c")($"a", $"c")
       .analyze
     val optimized = Optimize.execute(query)
     comparePlans(optimized, expected)
@@ -151,7 +159,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
 
   test("Remove redundant aggregate - upper has contains foldable expressions") {
     val originalQuery = x.groupBy($"a", $"b")($"a", $"b").groupBy($"a")($"a", TrueLiteral).analyze
-    val correctAnswer = x.groupBy($"a")($"a", TrueLiteral).analyze
+    val correctAnswer = x.select($"a").groupBy($"a")($"a", TrueLiteral).analyze
     val optimized = Optimize.execute(originalQuery)
     comparePlans(optimized, correctAnswer)
   }
@@ -174,7 +182,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
       .analyze
     val expected = relation
       .groupBy($"a")($"a", ($"a" + rand(0)) as "c")
-      .select($"a", $"c")
       .analyze
     val optimized = Optimize.execute(query)
     comparePlans(optimized, expected)
@@ -187,7 +194,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
         .groupBy("x.a".attr, "x.b".attr)("x.a".attr, "x.b".attr)
       val correctAnswer = x.groupBy($"a", $"b")($"a", $"b")
         .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
-        .select("x.a".attr, "x.b".attr)
 
       val optimized = Optimize.execute(originalQuery.analyze)
       comparePlans(optimized, correctAnswer.analyze)
@@ -201,7 +207,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
         .groupBy("x.a".attr, "d".attr)("x.a".attr, "d".attr)
       val correctAnswer = x.groupBy($"a", $"b")($"a", $"b".as("d"))
         .join(y, joinType, Some("x.a".attr === "y.a".attr && "d".attr === "y.b".attr))
-        .select("x.a".attr, "d".attr)
 
       val optimized = Optimize.execute(originalQuery.analyze)
       comparePlans(optimized, correctAnswer.analyze)
@@ -231,7 +236,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
         .groupBy("x.a".attr, "x.b".attr)("x.a".attr)
       val correctAnswer = x.groupBy($"a", $"b")($"a", $"b")
         .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
-        .select("x.a".attr, "x.b".attr)
         .join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
         .select("x.a".attr)
 
@@ -247,7 +251,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
       .analyze
     val correctAnswer = relation
       .groupBy($"a")($"a", count($"b").as("cnt"))
-      .select($"a", $"cnt")
       .analyze
     val optimized = Optimize.execute(originalQuery)
     comparePlans(optimized, correctAnswer)
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
index 67e1ddd32ea..202ceee1804 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
@@ -1175,3 +1175,24 @@ Sort [a#x ASC NULLS FIRST], true
             +- Project [a#x, b#x]
                +- SubqueryAlias testData
                   +- LocalRelation [a#x, b#x]
+
+
+-- !query
+SELECT c * 2 AS d
+FROM (
+         SELECT if(b > 1, 1, b) AS c
+         FROM (
+                  SELECT if(a < 0, 0, a) AS b
+                  FROM VALUES (-1), (1), (2) AS t1(a)
+              ) t2
+         GROUP BY b
+     ) t3
+GROUP BY c
+-- !query analysis
+Aggregate [c#x], [(c#x * 2) AS d#x]
++- SubqueryAlias t3
+   +- Aggregate [b#x], [if ((b#x > 1)) 1 else b#x AS c#x]
+      +- SubqueryAlias t2
+         +- Project [if ((a#x < 0)) 0 else a#x AS b#x]
+            +- SubqueryAlias t1
+               +- LocalRelation [a#x]
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index c812403ba2c..c35cdb0de27 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -251,3 +251,16 @@ GROUP BY a;
 
 SELECT mode(a), mode(b) FROM testData;
 SELECT a, mode(b) FROM testData GROUP BY a ORDER BY a;
+
+
+-- SPARK-44846: PushFoldableIntoBranches in complex grouping expressions cause bindReference error
+SELECT c * 2 AS d
+FROM (
+         SELECT if(b > 1, 1, b) AS c
+         FROM (
+                  SELECT if(a < 0, 0, a) AS b
+                  FROM VALUES (-1), (1), (2) AS t1(a)
+              ) t2
+         GROUP BY b
+     ) t3
+GROUP BY c;
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index acdf8d5a854..db79646fe43 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1103,3 +1103,21 @@ NULL	1
 1	1
 2	1
 3	1
+
+
+-- !query
+SELECT c * 2 AS d
+FROM (
+         SELECT if(b > 1, 1, b) AS c
+         FROM (
+                  SELECT if(a < 0, 0, a) AS b
+                  FROM VALUES (-1), (1), (2) AS t1(a)
+              ) t2
+         GROUP BY b
+     ) t3
+GROUP BY c
+-- !query schema
+struct<d:int>
+-- !query output
+0
+2


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org