You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yu...@apache.org on 2023/09/04 12:23:56 UTC
[spark] branch master updated: [SPARK-44846][SQL] Convert the lower redundant Aggregate to Project in RemoveRedundantAggregates
This is an automated email from the ASF dual-hosted git repository.
yumwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 32a87f03da7 [SPARK-44846][SQL] Convert the lower redundant Aggregate to Project in RemoveRedundantAggregates
32a87f03da7 is described below
commit 32a87f03da7eef41161a5a7a3aba4a48e0421912
Author: zml1206 <zh...@gmail.com>
AuthorDate: Mon Sep 4 20:23:39 2023 +0800
[SPARK-44846][SQL] Convert the lower redundant Aggregate to Project in RemoveRedundantAggregates
### What changes were proposed in this pull request?
This PR provides a safe way to remove a redundant `Aggregate` in the `RemoveRedundantAggregates` rule: instead of merging the two aggregates, it converts the lower redundant `Aggregate` into a `Project`.
### Why are the changes needed?
After `RemoveRedundantAggregates` runs, the resulting aggregate may contain complex grouping expressions. If `aggregateExpressions` has (if / case) branches, it is possible that `groupingExpressions` is no longer a subexpression of `aggregateExpressions` after the `PushFoldableIntoBranches` rule is executed, which then causes a `bindReference` error.
For example
```
SELECT c * 2 AS d
FROM (
SELECT if(b > 1, 1, b) AS c
FROM (
SELECT if(a < 0, 0, a) AS b
FROM VALUES (-1), (1), (2) AS t1(a)
) t2
GROUP BY b
) t3
GROUP BY c
```
Before this PR:
```
== Optimized Logical Plan ==
Aggregate [if ((b#0 > 1)) 1 else b#0], [if ((b#0 > 1)) 2 else (b#0 * 2) AS d#2]
+- Project [if ((a#3 < 0)) 0 else a#3 AS b#0]
+- LocalRelation [a#3]
```
```
== Error ==
Couldn't find b#0 in [if ((b#0 > 1)) 1 else b#0#7]
java.lang.IllegalStateException: Couldn't find b#0 in [if ((b#0 > 1)) 1 else b#0#7]
at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80)
at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:461)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:76)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:461)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:466)
at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren(TreeNode.scala:1241)
at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren$(TreeNode.scala:1240)
at org.apache.spark.sql.catalyst.expressions.BinaryExpression.mapChildren(Expression.scala:653)
......
```
After this PR:
```
== Optimized Logical Plan ==
Aggregate [c#1], [(c#1 * 2) AS d#2]
+- Project [if ((b#0 > 1)) 1 else b#0 AS c#1]
+- Project [if ((a#3 < 0)) 0 else a#3 AS b#0]
+- LocalRelation [a#3]
```
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UT
Closes #42633 from zml1206/SPARK-44846-2.
Authored-by: zml1206 <zh...@gmail.com>
Signed-off-by: Yuming Wang <yu...@ebay.com>
---
.../optimizer/RemoveRedundantAggregates.scala | 19 ++-----------------
.../optimizer/RemoveRedundantAggregatesSuite.scala | 21 ++++++++++++---------
.../sql-tests/analyzer-results/group-by.sql.out | 21 +++++++++++++++++++++
.../test/resources/sql-tests/inputs/group-by.sql | 13 +++++++++++++
.../resources/sql-tests/results/group-by.sql.out | 18 ++++++++++++++++++
5 files changed, 66 insertions(+), 26 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
index 93f3557a8c8..badf4065f5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregates.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.analysis.PullOutNondeterministic
import org.apache.spark.sql.catalyst.expressions.{AliasHelper, AttributeSet, ExpressionSet}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
@@ -32,22 +31,8 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(AGGREGATE), ruleId) {
case upper @ Aggregate(_, _, lower: Aggregate) if isLowerRedundant(upper, lower) =>
- val aliasMap = getAliasMap(lower)
-
- val newAggregate = upper.copy(
- child = lower.child,
- groupingExpressions = upper.groupingExpressions.map(replaceAlias(_, aliasMap)),
- aggregateExpressions = upper.aggregateExpressions.map(
- replaceAliasButKeepName(_, aliasMap))
- )
-
- // We might have introduces non-deterministic grouping expression
- if (newAggregate.groupingExpressions.exists(!_.deterministic)) {
- PullOutNondeterministic.applyLocally.applyOrElse(newAggregate, identity[LogicalPlan])
- } else {
- newAggregate
- }
-
+ val projectList = lower.aggregateExpressions.filter(upper.references.contains(_))
+ upper.copy(child = Project(projectList, lower.child))
case agg @ Aggregate(groupingExps, _, child)
if agg.groupOnly && child.distinctKeys.exists(_.subsetOf(ExpressionSet(groupingExps))) =>
Project(agg.aggregateExpressions, child)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
index 91201b67501..2af3057c0b8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAggregatesSuite.scala
@@ -30,7 +30,8 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("RemoveRedundantAggregates", FixedPoint(10),
- RemoveRedundantAggregates) :: Nil
+ RemoveRedundantAggregates,
+ RemoveNoopOperators) :: Nil
}
private val relation = LocalRelation($"a".int, $"b".int)
@@ -52,6 +53,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy($"a")($"a")
.analyze
val expected = relation
+ .select($"a")
.groupBy($"a")($"a")
.analyze
val optimized = Optimize.execute(query)
@@ -67,6 +69,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy($"a")($"a")
.analyze
val expected = relation
+ .select($"a")
.groupBy($"a")($"a")
.analyze
val optimized = Optimize.execute(query)
@@ -80,6 +83,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy($"a")($"a")
.analyze
val expected = relation
+ .select($"a")
.groupBy($"a")($"a")
.analyze
val optimized = Optimize.execute(query)
@@ -93,7 +97,8 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy($"c")($"c")
.analyze
val expected = relation
- .groupBy($"a" + $"b")(($"a" + $"b") as "c")
+ .select(($"a" + $"b") as "c")
+ .groupBy($"c")($"c")
.analyze
val optimized = Optimize.execute(query)
comparePlans(optimized, expected)
@@ -106,6 +111,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy($"a")($"a", rand(0) as "c")
.analyze
val expected = relation
+ .select($"a")
.groupBy($"a")($"a", rand(0) as "c")
.analyze
val optimized = Optimize.execute(query)
@@ -118,7 +124,9 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy($"a", $"c")($"a", $"c")
.analyze
val expected = relation
- .groupBy($"a", $"c")($"a", rand(0) as "c")
+ .select($"a", $"b", rand(0) as "_nondeterministic")
+ .select($"a", $"_nondeterministic" as "c")
+ .groupBy($"a", $"c")($"a", $"c")
.analyze
val optimized = Optimize.execute(query)
comparePlans(optimized, expected)
@@ -151,7 +159,7 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
test("Remove redundant aggregate - upper has contains foldable expressions") {
val originalQuery = x.groupBy($"a", $"b")($"a", $"b").groupBy($"a")($"a", TrueLiteral).analyze
- val correctAnswer = x.groupBy($"a")($"a", TrueLiteral).analyze
+ val correctAnswer = x.select($"a").groupBy($"a")($"a", TrueLiteral).analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
}
@@ -174,7 +182,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.analyze
val expected = relation
.groupBy($"a")($"a", ($"a" + rand(0)) as "c")
- .select($"a", $"c")
.analyze
val optimized = Optimize.execute(query)
comparePlans(optimized, expected)
@@ -187,7 +194,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy("x.a".attr, "x.b".attr)("x.a".attr, "x.b".attr)
val correctAnswer = x.groupBy($"a", $"b")($"a", $"b")
.join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
- .select("x.a".attr, "x.b".attr)
val optimized = Optimize.execute(originalQuery.analyze)
comparePlans(optimized, correctAnswer.analyze)
@@ -201,7 +207,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy("x.a".attr, "d".attr)("x.a".attr, "d".attr)
val correctAnswer = x.groupBy($"a", $"b")($"a", $"b".as("d"))
.join(y, joinType, Some("x.a".attr === "y.a".attr && "d".attr === "y.b".attr))
- .select("x.a".attr, "d".attr)
val optimized = Optimize.execute(originalQuery.analyze)
comparePlans(optimized, correctAnswer.analyze)
@@ -231,7 +236,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.groupBy("x.a".attr, "x.b".attr)("x.a".attr)
val correctAnswer = x.groupBy($"a", $"b")($"a", $"b")
.join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
- .select("x.a".attr, "x.b".attr)
.join(y, joinType, Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr))
.select("x.a".attr)
@@ -247,7 +251,6 @@ class RemoveRedundantAggregatesSuite extends PlanTest {
.analyze
val correctAnswer = relation
.groupBy($"a")($"a", count($"b").as("cnt"))
- .select($"a", $"cnt")
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
index 67e1ddd32ea..202ceee1804 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out
@@ -1175,3 +1175,24 @@ Sort [a#x ASC NULLS FIRST], true
+- Project [a#x, b#x]
+- SubqueryAlias testData
+- LocalRelation [a#x, b#x]
+
+
+-- !query
+SELECT c * 2 AS d
+FROM (
+ SELECT if(b > 1, 1, b) AS c
+ FROM (
+ SELECT if(a < 0, 0, a) AS b
+ FROM VALUES (-1), (1), (2) AS t1(a)
+ ) t2
+ GROUP BY b
+ ) t3
+GROUP BY c
+-- !query analysis
+Aggregate [c#x], [(c#x * 2) AS d#x]
++- SubqueryAlias t3
+ +- Aggregate [b#x], [if ((b#x > 1)) 1 else b#x AS c#x]
+ +- SubqueryAlias t2
+ +- Project [if ((a#x < 0)) 0 else a#x AS b#x]
+ +- SubqueryAlias t1
+ +- LocalRelation [a#x]
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index c812403ba2c..c35cdb0de27 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -251,3 +251,16 @@ GROUP BY a;
SELECT mode(a), mode(b) FROM testData;
SELECT a, mode(b) FROM testData GROUP BY a ORDER BY a;
+
+
+-- SPARK-44846: PushFoldableIntoBranches in complex grouping expressions cause bindReference error
+SELECT c * 2 AS d
+FROM (
+ SELECT if(b > 1, 1, b) AS c
+ FROM (
+ SELECT if(a < 0, 0, a) AS b
+ FROM VALUES (-1), (1), (2) AS t1(a)
+ ) t2
+ GROUP BY b
+ ) t3
+GROUP BY c;
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index acdf8d5a854..db79646fe43 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1103,3 +1103,21 @@ NULL 1
1 1
2 1
3 1
+
+
+-- !query
+SELECT c * 2 AS d
+FROM (
+ SELECT if(b > 1, 1, b) AS c
+ FROM (
+ SELECT if(a < 0, 0, a) AS b
+ FROM VALUES (-1), (1), (2) AS t1(a)
+ ) t2
+ GROUP BY b
+ ) t3
+GROUP BY c
+-- !query schema
+struct<d:int>
+-- !query output
+0
+2
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org