You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/03/28 03:09:12 UTC

[spark] branch branch-3.4 updated: [SPARK-42936][SQL] Fix LCA bug when the having clause can be resolved directly by its child Aggregate

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 46866aa5575 [SPARK-42936][SQL] Fix LCA bug when the having clause can be resolved directly by its child Aggregate
46866aa5575 is described below

commit 46866aa55759307251264cd2626438c467a5ff25
Author: Xinyi Yu <xi...@databricks.com>
AuthorDate: Tue Mar 28 11:08:30 2023 +0800

    [SPARK-42936][SQL] Fix LCA bug when the having clause can be resolved directly by its child Aggregate
    
    ### What changes were proposed in this pull request?
    
    The PR fixes the following bug in LCA + having resolution:
    ```sql
    select sum(value1) as total_1, total_1
    from values(1, 'name', 100, 50) AS data(id, name, value1, value2)
    having total_1 > 0
    
    SparkException: [INTERNAL_ERROR] Found the unresolved operator: 'UnresolvedHaving (total_1#353L > cast(0 as bigint))
    ```
    To trigger the issue, the having condition needs to be (i.e., be resolvable by) an attribute in the select list.
    Without the LCA `total_1`, the query works fine.
    
    #### Root cause of the issue
    `UnresolvedHaving` with `Aggregate` as child can use both the `Aggregate`'s output and the `Aggregate`'s child's output to resolve the having condition. If using the latter, `ResolveReferences` rule will replace the unresolved attribute with a `TempResolvedColumn`.
    For an `UnresolvedHaving` that can actually be resolved directly by its child `Aggregate`, there will be no `TempResolvedColumn` after the rule `ResolveReferences` is applied. This `UnresolvedHaving` still needs to be transformed to a `Filter` by the rule `ResolveAggregateFunctions`, which recognizes the shape `UnresolvedHaving - Aggregate`.
    However, the current condition (the plan should not contain `TempResolvedColumn`) that prevents the LCA rule from applying between `ResolveReferences` and `ResolveAggregateFunctions` does not cover the above case. The LCA rule can therefore insert a `Project` in the middle and break the shape that `ResolveAggregateFunctions` matches.
    
    #### Fix
    The PR adds another condition for the LCA rule to apply: the plan should not contain any `UnresolvedHaving`.
    
    ### Why are the changes needed?
    
    See above reasoning to fix the bug.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing and added tests.
    
    Closes #40558 from anchovYu/lca-having-bug-fix.
    
    Authored-by: Xinyi Yu <xi...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 2319a31fbf391c87bf8a1eef8707f46bef006c0f)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/analysis/CheckAnalysis.scala      | 58 +++++++++-------------
 .../ResolveLateralColumnAliasReference.scala       | 11 ++--
 .../spark/sql/catalyst/analysis/unresolved.scala   |  1 +
 .../spark/sql/catalyst/trees/TreePatterns.scala    |  1 +
 .../apache/spark/sql/LateralColumnAliasSuite.scala | 39 +++++++++++++--
 5 files changed, 67 insertions(+), 43 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 3af15d2465a..454ad70a509 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -111,6 +111,18 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
     }
   }
 
+  /** Check and throw exception when a given resolved plan contains LateralColumnAliasReference. */
+  private def checkNotContainingLCA(exprSeq: Seq[NamedExpression], plan: LogicalPlan): Unit = {
+    if (!plan.resolved) return
+    exprSeq.foreach(_.transformDownWithPruning(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
+      case lcaRef: LateralColumnAliasReference =>
+        throw SparkException.internalError("Resolved plan should not contain any " +
+          s"LateralColumnAliasReference.\nDebugging information: plan:\n$plan",
+          context = lcaRef.origin.getQueryContext,
+          summary = lcaRef.origin.context.summary)
+    })
+  }
+
   private def isMapWithStringKey(e: Expression): Boolean = if (e.resolved) {
     e.dataType match {
       case m: MapType => m.keyType.isInstanceOf[StringType]
@@ -653,16 +665,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
               case UnresolvedWindowExpression(_, windowSpec) =>
                 throw QueryCompilationErrors.windowSpecificationNotDefinedError(windowSpec.name)
             })
-            // This should not happen, resolved Project or Aggregate should restore or resolve
-            // all lateral column alias references. Add check for extra safe.
-            projectList.foreach(_.transformDownWithPruning(
-              _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
-              case lcaRef: LateralColumnAliasReference if p.resolved =>
-                throw SparkException.internalError("Resolved Project should not contain " +
-                  s"any LateralColumnAliasReference.\nDebugging information: plan: $p",
-                  context = lcaRef.origin.getQueryContext,
-                  summary = lcaRef.origin.context.summary)
-            })
 
           case j: Join if !j.duplicateResolved =>
             val conflictingAttributes = j.left.outputSet.intersect(j.right.outputSet)
@@ -738,31 +740,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
               messageParameters = Map(
                 "invalidExprSqls" -> invalidExprSqls.mkString(", ")))
 
-          // This should not happen, resolved Project or Aggregate should restore or resolve
-          // all lateral column alias references. Add check for extra safe.
-          case agg @ Aggregate(_, aggList, _)
-            if aggList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) && agg.resolved =>
-            aggList.foreach(_.transformDownWithPruning(
-              _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
-              case lcaRef: LateralColumnAliasReference =>
-                throw SparkException.internalError("Resolved Aggregate should not contain " +
-                  s"any LateralColumnAliasReference.\nDebugging information: plan: $agg",
-                  context = lcaRef.origin.getQueryContext,
-                  summary = lcaRef.origin.context.summary)
-            })
-
-          case w @ Window(pList, _, _, _)
-            if pList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) && w.resolved =>
-            pList.foreach(_.transformDownWithPruning(
-              _.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) {
-              case lcaRef: LateralColumnAliasReference =>
-                throw SparkException.internalError(
-                  s"Referencing lateral column alias ${toSQLId(lcaRef.nameParts)} is not " +
-                    s"supported in this Window query case yet. \nDebugging information: plan: $w",
-                  context = lcaRef.origin.getQueryContext,
-                  summary = lcaRef.origin.context.summary)
-            })
-
           case _ => // Analysis successful!
         }
     }
@@ -774,6 +751,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
           msg = s"Found the unresolved operator: ${o.simpleString(SQLConf.get.maxToStringFields)}",
           context = o.origin.getQueryContext,
           summary = o.origin.context.summary)
+      // If the plan is resolved, the resolved Project, Aggregate or Window should have restored or
+      // resolved all lateral column alias references. Add check for extra safe.
+      case p @ Project(pList, _)
+        if pList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
+        checkNotContainingLCA(pList, p)
+      case agg @ Aggregate(_, aggList, _)
+        if aggList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
+        checkNotContainingLCA(aggList, agg)
+      case w @ Window(pList, _, _, _)
+        if pList.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) =>
+        checkNotContainingLCA(pList, w)
       case _ =>
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
index 93156465d57..c4e8f241dfb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.WindowExpression.hasWindowExpre
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, TEMP_RESOLVED_COLUMN}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, TEMP_RESOLVED_COLUMN, UNRESOLVED_HAVING}
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -134,10 +134,11 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
     if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) {
       plan
-    } else if (plan.containsPattern(TEMP_RESOLVED_COLUMN)) {
-      // We should not change the plan if `TempResolvedColumn` is present in the query plan. It
-      // needs certain plan shape to get resolved, such as Filter/Sort + Aggregate. LCA resolution
-      // may break the plan shape, like adding Project above Aggregate.
+    } else if (plan.containsAnyPattern(TEMP_RESOLVED_COLUMN, UNRESOLVED_HAVING)) {
+      // It should not change the plan if `TempResolvedColumn` or `UnresolvedHaving` is present in
+      // the query plan. These plans need certain plan shape to get recognized and resolved by other
+      // rules, such as Filter/Sort + Aggregate to be matched by ResolveAggregateFunctions.
+      // LCA resolution can break the plan shape, like adding Project above Aggregate.
       plan
     } else {
       // phase 2: unwrap
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 04bc513aa99..ff002e9149a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -691,6 +691,7 @@ case class UnresolvedHaving(
   override def output: Seq[Attribute] = child.output
   override protected def withNewChildInternal(newChild: LogicalPlan): UnresolvedHaving =
     copy(child = newChild)
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_HAVING)
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index ce853b5773c..e5bf12e8472 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -133,6 +133,7 @@ object TreePattern extends Enumeration  {
   val UNRESOLVED_ALIAS: Value = Value
   val UNRESOLVED_ATTRIBUTE: Value = Value
   val UNRESOLVED_DESERIALIZER: Value = Value
+  val UNRESOLVED_HAVING: Value = Value
   val UNRESOLVED_ORDINAL: Value = Value
   val UNRESOLVED_FUNCTION: Value = Value
   val UNRESOLVED_HINT: Value = Value
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
index 6c0974d536f..5a7720db4d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala
@@ -634,19 +634,52 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase {
       lca = "`bar`", windowExprRegex = "\"RANK.*\"")
   }
 
-  test("Lateral alias reference attribute further be used by upper plan") {
-    // underlying this is not in the scope of lateral alias project but things already supported
+  test("Lateral alias reference works with having and order by") {
+    // order by is resolved by an attribute in project / aggregate
+    // this is not in the scope of lateral alias feature but things already supported
     checkAnswer(
       sql(s"SELECT properties AS new_properties, new_properties.joinYear AS new_join_year " +
         s"FROM $testTable WHERE dept = 1 ORDER BY new_join_year DESC"),
       Row(Row(2020, "B"), 2020) :: Row(Row(2019, "A"), 2019) :: Nil
     )
-
     checkAnswer(
       sql(s"SELECT avg(bonus) AS avg_bonus, avg_bonus * 1.0 AS new_avg_bonus, avg(salary) " +
         s"FROM $testTable GROUP BY dept ORDER BY new_avg_bonus"),
       Row(1100, 1100, 9500.0) :: Row(1200, 1200, 12000) :: Row(1250, 1250, 11000) :: Nil
     )
+    checkAnswer(
+      sql(s"SELECT avg(bonus) AS dept, dept, avg(salary) AS a, a + 10 AS b " +
+        s"FROM $testTable GROUP BY dept ORDER BY dept"),
+      Row(1100, 1, 9500, 9510) :: Row(1250, 2, 11000, 11010) :: Row(1200, 6, 12000, 12010) :: Nil
+    )
+    // order by is resolved by aggregate's child
+    checkAnswer(
+      sql(s"SELECT avg(bonus) AS dept, dept, avg(salary) AS a, a + 10 AS b " +
+        s"FROM $testTable GROUP BY dept ORDER BY max(name)"),
+      Row(1100, 1, 9500, 9510) :: Row(1250, 2, 11000, 11010) :: Row(1200, 6, 12000, 12010) :: Nil
+    )
+    checkAnswer(
+      sql(s"SELECT avg(bonus) AS dept, dept, avg(salary) AS a, a " + // no extra calculation
+        s"FROM $testTable GROUP BY dept ORDER BY dept"),
+      Row(1100, 1, 9500, 9500) :: Row(1250, 2, 11000, 11000) :: Row(1200, 6, 12000, 12000) :: Nil
+    )
+    checkAnswer(
+      sql(s"SELECT dept as a, a " + // even no extra function resolution
+        s"FROM $testTable GROUP BY dept ORDER BY max(name)"),
+      Row(1, 1) :: Row(2, 2) :: Row(6, 6) :: Nil
+    )
+
+    // having cond is resolved by aggregate's child
+    checkAnswer(
+      sql(s"SELECT avg(bonus) AS dept, dept, avg(salary) AS a, a + 10 AS b " +
+        s"FROM $testTable GROUP BY dept HAVING max(name) = 'david'"),
+      Row(1250, 2, 11000, 11010) :: Nil
+    )
+    // having cond is resolved by aggregate itself
+    checkAnswer(
+      sql(s"SELECT avg(bonus) AS a, a FROM $testTable GROUP BY dept HAVING a > 1200"),
+      Row(1250, 1250) :: Nil
+    )
   }
 
   test("Lateral alias chaining") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org