You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2020/05/16 05:34:05 UTC
[spark] branch branch-2.4 updated: [SPARK-31663][SQL] Grouping sets
with having clause returns the wrong result
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new a4885f3 [SPARK-31663][SQL] Grouping sets with having clause returns the wrong result
a4885f3 is described below
commit a4885f3654899bcb852183af70cc0a82e7dd81d0
Author: Yuanjian Li <xy...@gmail.com>
AuthorDate: Sat May 16 04:37:18 2020 +0000
[SPARK-31663][SQL] Grouping sets with having clause returns the wrong result
- Resolve the having condition together with expanding the GROUPING SETS/CUBE/ROLLUP expressions in `ResolveGroupingAnalytics`:
- Change the operations resolving directions to top-down.
- Try resolving the condition of the filter as though it is in the aggregate clause by reusing the function in `ResolveAggregateFunctions`
- Push the aggregate expressions into the aggregate which contains the expanded operations.
- Use UnresolvedHaving for all having clauses.
Correctness bug fix. See the demo and analysis in SPARK-31663.
Yes, correctness bug fix for HAVING with GROUPING SETS.
New UTs added.
Closes #28501 from xuanyuanking/SPARK-31663.
Authored-by: Yuanjian Li <xy...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
(cherry picked from commit 86bd37f37eb1e534c520dc9a02387debf9fa05a1)
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 136 ++++++++++++++++-----
.../spark/sql/catalyst/analysis/unresolved.scala | 7 +-
.../apache/spark/sql/catalyst/dsl/package.scala | 2 +-
.../spark/sql/catalyst/parser/AstBuilder.scala | 7 +-
.../src/test/resources/sql-tests/inputs/having.sql | 7 +-
.../resources/sql-tests/results/having.sql.out | 28 ++++-
6 files changed, 146 insertions(+), 41 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f10276d..43ab651 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -431,20 +431,13 @@ class Analyzer(
}.asInstanceOf[NamedExpression]
}
- /*
- * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
- */
- private def constructAggregate(
+ private def getFinalGroupByExpressions(
selectedGroupByExprs: Seq[Seq[Expression]],
- groupByExprs: Seq[Expression],
- aggregationExprs: Seq[NamedExpression],
- child: LogicalPlan): LogicalPlan = {
- val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
-
+ groupByExprs: Seq[Expression]): Seq[Expression] = {
// In case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and
// can be null. In such case, we derive the groupByExprs from the user supplied values for
// grouping sets.
- val finalGroupByExpressions = if (groupByExprs == Nil) {
+ if (groupByExprs == Nil) {
selectedGroupByExprs.flatten.foldLeft(Seq.empty[Expression]) { (result, currentExpr) =>
// Only unique expressions are included in the group by expressions and is determined
// based on their semantic equality. Example. grouping sets ((a * b), (b * a)) results
@@ -458,6 +451,18 @@ class Analyzer(
} else {
groupByExprs
}
+ }
+
+ /*
+ * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
+ */
+ private def constructAggregate(
+ selectedGroupByExprs: Seq[Seq[Expression]],
+ groupByExprs: Seq[Expression],
+ aggregationExprs: Seq[NamedExpression],
+ child: LogicalPlan): LogicalPlan = {
+ val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
+ val finalGroupByExpressions = getFinalGroupByExpressions(selectedGroupByExprs, groupByExprs)
// Expand works by setting grouping expressions to null as determined by the
// `selectedGroupByExprs`. To prevent these null values from being used in an aggregate
@@ -489,8 +494,70 @@ class Analyzer(
}
}
- // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
+ private def tryResolveHavingCondition(h: UnresolvedHaving): LogicalPlan = {
+ val aggForResolving = h.child match {
+ // For CUBE/ROLLUP, replace the Cube/Rollup wrapper in groupingExpressions with its plain
+ // group-by expressions so that it is not resolved again when resolving the condition.
+ case a @ Aggregate(Seq(c @ Cube(groupByExprs)), _, _) =>
+ a.copy(groupingExpressions = groupByExprs)
+ case a @ Aggregate(Seq(r @ Rollup(groupByExprs)), _, _) =>
+ a.copy(groupingExpressions = groupByExprs)
+ case g: GroupingSets =>
+ Aggregate(
+ getFinalGroupByExpressions(g.selectedGroupByExprs, g.groupByExprs),
+ g.aggregations, g.child)
+ }
+ // Try resolving the having condition as though it is in the aggregate clause.
+ val resolvedInfo =
+ ResolveAggregateFunctions.resolveFilterCondInAggregate(h.havingCondition, aggForResolving)
+
+ // Push the extra aggregate expressions into the expanded aggregate (if any).
+ if (resolvedInfo.nonEmpty) {
+ val (extraAggExprs, resolvedHavingCond) = resolvedInfo.get
+ val newChild = h.child match {
+ case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
+ constructAggregate(
+ cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
+ case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
+ constructAggregate(
+ rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
+ case x: GroupingSets =>
+ constructAggregate(
+ x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child)
+ }
+
+ // The exprIds of extraAggExprs are changed in the constructed aggregate, while its
+ // aggregateExpressions keep the input order. So here we build an exprMap from the old
+ // expressions to the new ones and use it to rewrite the resolved condition.
+ val exprMap = extraAggExprs.zip(
+ newChild.asInstanceOf[Aggregate].aggregateExpressions.takeRight(
+ extraAggExprs.length)).toMap
+ val newCond = resolvedHavingCond.transform {
+ case ne: NamedExpression if exprMap.contains(ne) => exprMap(ne)
+ }
+ Project(newChild.output.dropRight(extraAggExprs.length),
+ Filter(newCond, newChild))
+ } else {
+ h
+ }
+ }
+
+ // This requires transformDown to resolve the having condition when generating the aggregate
+ // node for CUBE/ROLLUP/GROUPING SETS. This also replaces grouping()/grouping_id() in resolved
+ // Filter/Sort.
+ def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+ case h @ UnresolvedHaving(
+ _, agg @ Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, _))
+ if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
+ tryResolveHavingCondition(h)
+ case h @ UnresolvedHaving(
+ _, agg @ Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, _))
+ if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
+ tryResolveHavingCondition(h)
+ case h @ UnresolvedHaving(_, g: GroupingSets)
+ if g.childrenResolved && g.expressions.forall(_.resolved) =>
+ tryResolveHavingCondition(h)
+
case a if !a.childrenResolved => a // be sure all of the children are resolved.
// Ensure group by expressions and aggregate expressions have been resolved.
@@ -964,7 +1031,7 @@ class Analyzer(
case plan if containsDeserializer(plan.expressions) => plan
// Skip the having clause here, this will be handled in ResolveAggregateFunctions.
- case h: AggregateWithHaving => h
+ case h: UnresolvedHaving => h
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
@@ -1542,7 +1609,7 @@ class Analyzer(
// Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly
// resolve the having condition expression, here we skip resolving it in ResolveReferences
// and transform it to Filter after aggregate is resolved. See more details in SPARK-31519.
- case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
+ case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved =>
resolveHaving(Filter(cond, agg), agg)
case f @ Filter(_, agg: Aggregate) if agg.resolved =>
@@ -1618,13 +1685,13 @@ class Analyzer(
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}
- def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
- // Try resolving the condition of the filter as though it is in the aggregate clause
+ def resolveFilterCondInAggregate(
+ filterCond: Expression, agg: Aggregate): Option[(Seq[NamedExpression], Expression)] = {
try {
val aggregatedCondition =
Aggregate(
agg.groupingExpressions,
- Alias(filter.condition, "havingCondition")() :: Nil,
+ Alias(filterCond, "havingCondition")() :: Nil,
agg.child)
val resolvedOperator = executeSameContext(aggregatedCondition)
def resolvedAggregateFilter =
@@ -1656,22 +1723,33 @@ class Analyzer(
alias.toAttribute
}
}
-
- // Push the aggregate expressions into the aggregate (if any).
if (aggregateExpressions.nonEmpty) {
- Project(agg.output,
- Filter(transformedAggregateFilter,
- agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
+ Some((aggregateExpressions, transformedAggregateFilter))
} else {
- filter
+ None
}
} else {
- filter
+ None
}
} catch {
- // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
- // just return the original plan.
- case ae: AnalysisException => filter
+ // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
+ // just return None and the caller side will return the original plan.
+ case ae: AnalysisException => None
+ }
+ }
+
+ def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
+ // Try resolving the filter condition as though it is in the aggregate clause.
+ val resolvedInfo = resolveFilterCondInAggregate(filter.condition, agg)
+
+ // Push the extra aggregate expressions into the aggregate (if any) and project them away.
+ if (resolvedInfo.nonEmpty) {
+ val (aggregateExpressions, resolvedHavingCond) = resolvedInfo.get
+ Project(agg.output,
+ Filter(resolvedHavingCond,
+ agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
+ } else {
+ filter
+ }
}
}
}
@@ -2064,12 +2142,12 @@ class Analyzer(
case Filter(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside WHERE clause")
- case AggregateWithHaving(condition, _) if hasWindowFunction(condition) =>
+ case UnresolvedHaving(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside HAVING clause")
// Aggregate with Having clause. This rule works with an unresolved Aggregate because
// a resolved Aggregate will not have Window Functions.
- case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
+ case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index bcd2ff7..3708ee9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -515,11 +515,12 @@ case class UnresolvedOrdinal(ordinal: Int)
}
/**
- * Represents unresolved aggregate with having clause, it is turned by the analyzer into a Filter.
+ * Represents an unresolved having clause. Its child can be an Aggregate, GroupingSets, Rollup
+ * or Cube. It is turned by the analyzer into a Filter.
*/
-case class AggregateWithHaving(
+case class UnresolvedHaving(
havingCondition: Expression,
- child: Aggregate)
+ child: LogicalPlan)
extends UnaryNode {
override lazy val resolved: Boolean = false
override def output: Seq[Attribute] = child.output
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index f7b1638..3dd0af5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -360,7 +360,7 @@ package object dsl {
groupingExprs: Expression*)(
aggregateExprs: Expression*)(
havingCondition: Expression): LogicalPlan = {
- AggregateWithHaving(havingCondition,
+ UnresolvedHaving(havingCondition,
groupBy(groupingExprs: _*)(aggregateExprs: _*).asInstanceOf[Aggregate])
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 22d5f1d..e2e8a45 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -401,12 +401,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case p: Predicate => p
case e => Cast(e, BooleanType)
}
- plan match {
- case aggregate: Aggregate =>
- AggregateWithHaving(predicate, aggregate)
- case _ =>
- Filter(predicate, plan)
- }
+ UnresolvedHaving(predicate, plan)
}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql
index 179686e..ccc633c 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -18,4 +18,9 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;
-- SPARK-31519: Datetime functions in having aggregate expressions returns the wrong result
-SELECT SUM(a) AS b, hour('2020-01-01 12:12:12') AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
+SELECT SUM(a) AS b, hour('2020-01-01 12:12:12') AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10;
+
+-- SPARK-31663: Grouping sets with having clause returns the wrong result
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10;
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10;
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10;
diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out
index fd594f5..79d6416 100644
--- a/sql/core/src/test/resources/sql-tests/results/having.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 6
+-- Number of queries: 9
-- !query 0
@@ -55,3 +55,29 @@ SELECT SUM(a) AS b, hour('2020-01-01 12:12:12') AS fake FROM VALUES (1, 10), (2,
struct<b:bigint,fake:int>
-- !query 5 output
2 12
+
+
+-- !query 6
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10
+-- !query 6 schema
+struct<b:bigint>
+-- !query 6 output
+2
+2
+
+
+-- !query 7
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10
+-- !query 7 schema
+struct<b:bigint>
+-- !query 7 output
+2
+2
+
+
+-- !query 8
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10
+-- !query 8 schema
+struct<b:bigint>
+-- !query 8 output
+2
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org