Posted to commits@spark.apache.org by we...@apache.org on 2020/05/16 05:34:05 UTC

[spark] branch branch-2.4 updated: [SPARK-31663][SQL] Grouping sets with having clause returns the wrong result

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new a4885f3  [SPARK-31663][SQL] Grouping sets with having clause returns the wrong result
a4885f3 is described below

commit a4885f3654899bcb852183af70cc0a82e7dd81d0
Author: Yuanjian Li <xy...@gmail.com>
AuthorDate: Sat May 16 04:37:18 2020 +0000

    [SPARK-31663][SQL] Grouping sets with having clause returns the wrong result
    
    - Resolve the having condition while expanding the GROUPING SETS/CUBE/ROLLUP expressions together in `ResolveGroupingAnalytics`:
        - Change the rule's resolving direction to top-down.
        - Try resolving the filter condition as though it were in the aggregate clause, by reusing the function in `ResolveAggregateFunctions`.
        - Push the aggregate expressions into the aggregate that contains the expanded operations.
    - Use `UnresolvedHaving` for all having clauses.
    
    Correctness bug fix. See the demo and analysis in SPARK-31663.
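
    For context, a minimal sketch of the reproduction, assuming a spark-shell
    session where `spark` is the usual SparkSession (see SPARK-31663 for the
    full demo and analysis):

        spark.sql("""
          SELECT SUM(a) AS b
          FROM VALUES (1, 10), (2, 20) AS T(a, b)
          GROUP BY GROUPING SETS ((b), (a, b))
          HAVING b > 10
        """).show()

    After the fix this returns two rows of 2 (only the groups where b = 20
    survive the filter); before it, the condition could bind `b` to the wrong
    attribute and return an incorrect result.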
    
    Yes, this is a correctness bug fix for HAVING with GROUPING SETS.
    
    New unit tests added.
    
    Closes #28501 from xuanyuanking/SPARK-31663.
    
    Authored-by: Yuanjian Li <xy...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 86bd37f37eb1e534c520dc9a02387debf9fa05a1)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 136 ++++++++++++++++-----
 .../spark/sql/catalyst/analysis/unresolved.scala   |   7 +-
 .../apache/spark/sql/catalyst/dsl/package.scala    |   2 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala     |   7 +-
 .../src/test/resources/sql-tests/inputs/having.sql |   7 +-
 .../resources/sql-tests/results/having.sql.out     |  28 ++++-
 6 files changed, 146 insertions(+), 41 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f10276d..43ab651 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -431,20 +431,13 @@ class Analyzer(
       }.asInstanceOf[NamedExpression]
     }
 
-    /*
-     * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
-     */
-    private def constructAggregate(
+    private def getFinalGroupByExpressions(
         selectedGroupByExprs: Seq[Seq[Expression]],
-        groupByExprs: Seq[Expression],
-        aggregationExprs: Seq[NamedExpression],
-        child: LogicalPlan): LogicalPlan = {
-      val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
-
+        groupByExprs: Seq[Expression]): Seq[Expression] = {
       // In case of ANSI-SQL compliant syntax for GROUPING SETS, groupByExprs is optional and
       // can be null. In such case, we derive the groupByExprs from the user supplied values for
       // grouping sets.
-      val finalGroupByExpressions = if (groupByExprs == Nil) {
+      if (groupByExprs == Nil) {
         selectedGroupByExprs.flatten.foldLeft(Seq.empty[Expression]) { (result, currentExpr) =>
           // Only unique expressions are included in the group by expressions and is determined
           // based on their semantic equality. Example. grouping sets ((a * b), (b * a)) results
@@ -458,6 +451,18 @@ class Analyzer(
       } else {
         groupByExprs
       }
+    }
+
+    /*
+     * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
+     */
+    private def constructAggregate(
+        selectedGroupByExprs: Seq[Seq[Expression]],
+        groupByExprs: Seq[Expression],
+        aggregationExprs: Seq[NamedExpression],
+        child: LogicalPlan): LogicalPlan = {
+      val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)()
+      val finalGroupByExpressions = getFinalGroupByExpressions(selectedGroupByExprs, groupByExprs)
 
       // Expand works by setting grouping expressions to null as determined by the
       // `selectedGroupByExprs`. To prevent these null values from being used in an aggregate
@@ -489,8 +494,70 @@ class Analyzer(
       }
     }
 
-    // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
+    private def tryResolveHavingCondition(h: UnresolvedHaving): LogicalPlan = {
+      val aggForResolving = h.child match {
+        // To avoid resolving the CUBE/ROLLUP expressions repeatedly, replace them here with
+        // their underlying group-by expressions while resolving the condition.
+        case a @ Aggregate(Seq(c @ Cube(groupByExprs)), _, _) =>
+          a.copy(groupingExpressions = groupByExprs)
+        case a @ Aggregate(Seq(r @ Rollup(groupByExprs)), _, _) =>
+          a.copy(groupingExpressions = groupByExprs)
+        case g: GroupingSets =>
+          Aggregate(
+            getFinalGroupByExpressions(g.selectedGroupByExprs, g.groupByExprs),
+            g.aggregations, g.child)
+      }
+      // Try resolving the condition of the filter as though it is in the aggregate clause
+      val resolvedInfo =
+        ResolveAggregateFunctions.resolveFilterCondInAggregate(h.havingCondition, aggForResolving)
+
+      // Push the aggregate expressions into the aggregate (if any).
+      if (resolvedInfo.nonEmpty) {
+        val (extraAggExprs, resolvedHavingCond) = resolvedInfo.get
+        val newChild = h.child match {
+          case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) =>
+            constructAggregate(
+              cubeExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
+          case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) =>
+            constructAggregate(
+              rollupExprs(groupByExprs), groupByExprs, aggregateExpressions ++ extraAggExprs, child)
+          case x: GroupingSets =>
+            constructAggregate(
+              x.selectedGroupByExprs, x.groupByExprs, x.aggregations ++ extraAggExprs, x.child)
+        }
+
+        // The exprIds of extraAggExprs are re-created in the constructed aggregate, while
+        // aggregateExpressions keeps the input order, so we build an exprMap here to rewrite
+        // the condition against the new attributes.
+        val exprMap = extraAggExprs.zip(
+          newChild.asInstanceOf[Aggregate].aggregateExpressions.takeRight(
+            extraAggExprs.length)).toMap
+        val newCond = resolvedHavingCond.transform {
+          case ne: NamedExpression if exprMap.contains(ne) => exprMap(ne)
+        }
+        Project(newChild.output.dropRight(extraAggExprs.length),
+          Filter(newCond, newChild))
+      } else {
+        h
+      }
+    }
+
+    // This requires transformDown to resolve the having condition when generating the
+    // aggregate node for CUBE/ROLLUP/GROUPING SETS. It also replaces grouping()/grouping_id()
+    // in resolved Filter/Sort.
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {
+      case h @ UnresolvedHaving(
+          _, agg @ Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, _))
+          if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
+        tryResolveHavingCondition(h)
+      case h @ UnresolvedHaving(
+          _, agg @ Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, _))
+          if agg.childrenResolved && (groupByExprs ++ aggregateExpressions).forall(_.resolved) =>
+        tryResolveHavingCondition(h)
+      case h @ UnresolvedHaving(_, g: GroupingSets)
+          if g.childrenResolved && g.expressions.forall(_.resolved) =>
+        tryResolveHavingCondition(h)
+
       case a if !a.childrenResolved => a // be sure all of the children are resolved.
 
       // Ensure group by expressions and aggregate expressions have been resolved.
@@ -964,7 +1031,7 @@ class Analyzer(
       case plan if containsDeserializer(plan.expressions) => plan
 
       // Skip the having clause here, this will be handled in ResolveAggregateFunctions.
-      case h: AggregateWithHaving => h
+      case h: UnresolvedHaving => h
 
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
@@ -1542,7 +1609,7 @@ class Analyzer(
       // Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly
       // resolve the having condition expression, here we skip resolving it in ResolveReferences
       // and transform it to Filter after aggregate is resolved. See more details in SPARK-31519.
-      case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
+      case UnresolvedHaving(cond, agg: Aggregate) if agg.resolved =>
         resolveHaving(Filter(cond, agg), agg)
 
       case f @ Filter(_, agg: Aggregate) if agg.resolved =>
@@ -1618,13 +1685,13 @@ class Analyzer(
       condition.find(_.isInstanceOf[AggregateExpression]).isDefined
     }
 
-    def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
-      // Try resolving the condition of the filter as though it is in the aggregate clause
+    def resolveFilterCondInAggregate(
+        filterCond: Expression, agg: Aggregate): Option[(Seq[NamedExpression], Expression)] = {
       try {
         val aggregatedCondition =
           Aggregate(
             agg.groupingExpressions,
-            Alias(filter.condition, "havingCondition")() :: Nil,
+            Alias(filterCond, "havingCondition")() :: Nil,
             agg.child)
         val resolvedOperator = executeSameContext(aggregatedCondition)
         def resolvedAggregateFilter =
@@ -1656,22 +1723,33 @@ class Analyzer(
                   alias.toAttribute
               }
           }
-
-          // Push the aggregate expressions into the aggregate (if any).
           if (aggregateExpressions.nonEmpty) {
-            Project(agg.output,
-              Filter(transformedAggregateFilter,
-                agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
+            Some((aggregateExpressions, transformedAggregateFilter))
           } else {
-            filter
+            None
           }
         } else {
-          filter
+          None
         }
       } catch {
-        // Attempting to resolve in the aggregate can result in ambiguity.  When this happens,
-        // just return the original plan.
-        case ae: AnalysisException => filter
+        // Attempting to resolve in the aggregate can result in ambiguity. When this happens,
+        // just return None and let the caller fall back to the original plan.
+        case ae: AnalysisException => None
+      }
+    }
+
+    def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
+      // Try resolving the condition of the filter as though it is in the aggregate clause
+      val resolvedInfo = resolveFilterCondInAggregate(filter.condition, agg)
+
+      // Push the aggregate expressions into the aggregate (if any).
+      if (resolvedInfo.nonEmpty) {
+        val (aggregateExpressions, resolvedHavingCond) = resolvedInfo.get
+        Project(agg.output,
+          Filter(resolvedHavingCond,
+            agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
+      } else {
+        filter
       }
     }
   }
@@ -2064,12 +2142,12 @@ class Analyzer(
       case Filter(condition, _) if hasWindowFunction(condition) =>
         failAnalysis("It is not allowed to use window functions inside WHERE clause")
 
-      case AggregateWithHaving(condition, _) if hasWindowFunction(condition) =>
+      case UnresolvedHaving(condition, _) if hasWindowFunction(condition) =>
         failAnalysis("It is not allowed to use window functions inside HAVING clause")
 
       // Aggregate with Having clause. This rule works with an unresolved Aggregate because
       // a resolved Aggregate will not have Window Functions.
-      case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
+      case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
         if child.resolved &&
            hasWindowFunction(aggregateExprs) &&
            a.expressions.forall(_.resolved) =>
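
The subtle step in `tryResolveHavingCondition` above is the exprId remap: the
extra aggregate expressions get fresh exprIds when the real aggregate is
constructed, so the already-resolved condition must be rewritten against the
new attributes. A self-contained sketch of that remapping in plain Scala,
where `Expr`, `Attr`, `Gt` and `Lit` are hypothetical stand-ins for Catalyst's
expression types rather than the real API:

    // Minimal expression tree with an exprId on attributes.
    sealed trait Expr
    case class Attr(name: String, exprId: Long) extends Expr
    case class Lit(value: Int) extends Expr
    case class Gt(left: Expr, right: Expr) extends Expr

    // Top-down rewrite, analogous to Catalyst's Expression.transform.
    def transform(e: Expr)(rule: PartialFunction[Expr, Expr]): Expr =
      rule.applyOrElse(e, identity[Expr]) match {
        case Gt(l, r) => Gt(transform(l)(rule), transform(r)(rule))
        case leaf     => leaf
      }

    // The condition was resolved against the helper aggregate (old id 1),
    // but the constructed aggregate re-creates the attribute (new id 7).
    val extraAggExprs = Seq(Attr("b", exprId = 1L))
    val newOutputTail = Seq(Attr("b", exprId = 7L))
    val exprMap = extraAggExprs.zip(newOutputTail).toMap

    val resolvedHavingCond: Expr = Gt(Attr("b", 1L), Lit(10))
    val newCond = transform(resolvedHavingCond) {
      case a: Attr if exprMap.contains(a) => exprMap(a)
    }
    // newCond == Gt(Attr("b", 7L), Lit(10))
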
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index bcd2ff7..3708ee9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -515,11 +515,12 @@ case class UnresolvedOrdinal(ordinal: Int)
 }
 
 /**
- * Represents unresolved aggregate with having clause, it is turned by the analyzer into a Filter.
+ * Represents an unresolved having clause; its child can be an Aggregate, GroupingSets,
+ * Rollup or Cube. The analyzer turns it into a Filter.
  */
-case class AggregateWithHaving(
+case class UnresolvedHaving(
     havingCondition: Expression,
-    child: Aggregate)
+    child: LogicalPlan)
   extends UnaryNode {
   override lazy val resolved: Boolean = false
   override def output: Seq[Attribute] = child.output
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index f7b1638..3dd0af5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -360,7 +360,7 @@ package object dsl {
           groupingExprs: Expression*)(
           aggregateExprs: Expression*)(
           havingCondition: Expression): LogicalPlan = {
-        AggregateWithHaving(havingCondition,
+        UnresolvedHaving(havingCondition,
           groupBy(groupingExprs: _*)(aggregateExprs: _*).asInstanceOf[Aggregate])
       }
 
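For reference, a hypothetical test-style use of the DSL helper above (assuming
it is the DSL's `having` method, and a `testRelation` with columns a and b as
in Catalyst's analysis test suites):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._

    // Group by a, aggregate count(b) as cnt, keep groups with cnt > 1.
    val plan = testRelation.having('a)(count('b).as("cnt"))('cnt > 1)
    // `plan` is UnresolvedHaving(cnt > 1, Aggregate(...)); the analyzer
    // later rewrites it into Project(Filter(..., Aggregate(...))).
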
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 22d5f1d..e2e8a45 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -401,12 +401,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
         case p: Predicate => p
         case e => Cast(e, BooleanType)
       }
-      plan match {
-        case aggregate: Aggregate =>
-          AggregateWithHaving(predicate, aggregate)
-        case _ =>
-          Filter(predicate, plan)
-      }
+      UnresolvedHaving(predicate, plan)
     }
 
 
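With this change every HAVING clause parses to UnresolvedHaving, whether or
not the plan beneath it is an Aggregate; the analyzer decides later whether it
becomes a filter over an aggregate or a plain Filter. A sketch, assuming a
SparkSession named `spark`:

    // HAVING without GROUP BY: previously the parser special-cased this
    // into Filter(...); now it also goes through UnresolvedHaving.
    spark.sql(
      "SELECT MIN(v) AS m FROM VALUES (1), (2) AS t(v) HAVING COUNT(1) > 0"
    ).show()
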
diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql
index 179686e..ccc633c 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -18,4 +18,9 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
 SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;
 
 -- SPARK-31519: Datetime functions in having aggregate expressions returns the wrong result
-SELECT SUM(a) AS b, hour('2020-01-01 12:12:12') AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
+SELECT SUM(a) AS b, hour('2020-01-01 12:12:12') AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10;
+
+-- SPARK-31663: Grouping sets with having clause returns the wrong result
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10;
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10;
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10;
diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out
index fd594f5..79d6416 100644
--- a/sql/core/src/test/resources/sql-tests/results/having.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 6
+-- Number of queries: 9
 
 
 -- !query 0
@@ -55,3 +55,29 @@ SELECT SUM(a) AS b, hour('2020-01-01 12:12:12') AS fake FROM VALUES (1, 10), (2,
 struct<b:bigint,fake:int>
 -- !query 5 output
 2	12
+
+
+-- !query 6
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY GROUPING SETS ((b), (a, b)) HAVING b > 10
+-- !query 6 schema
+struct<b:bigint>
+-- !query 6 output
+2
+2
+
+
+-- !query 7
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY CUBE(a, b) HAVING b > 10
+-- !query 7 schema
+struct<b:bigint>
+-- !query 7 output
+2
+2
+
+
+-- !query 8
+SELECT SUM(a) AS b FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY ROLLUP(a, b) HAVING b > 10
+-- !query 8 schema
+struct<b:bigint>
+-- !query 8 output
+2
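
For the record, the expected outputs above can be checked by hand. The input
rows are (a, b) = (1, 10) and (2, 20), and HAVING b > 10 keeps only the groups
whose b value is 20:

    GROUPING SETS ((b), (a, b)) and CUBE(a, b):
      group (b = 20)         -> SUM(a) = 2   kept
      group (a = 2, b = 20)  -> SUM(a) = 2   kept
      groups with b = 10 or b = NULL         filtered out
      => two rows of 2 (queries 6 and 7)

    ROLLUP(a, b):
      group (a = 2, b = 20)  -> SUM(a) = 2   kept
      groups (a) and () have b = NULL        filtered out
      => one row of 2 (query 8)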

