You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2024/01/23 09:21:24 UTC

(spark) branch branch-3.5 updated: [SPARK-46763] Fix assertion failure in ReplaceDeduplicateWithAggregate for duplicate attributes

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new a559ff7bb9d3 [SPARK-46763] Fix assertion failure in ReplaceDeduplicateWithAggregate for duplicate attributes
a559ff7bb9d3 is described below

commit a559ff7bb9d3c34429f80760741f1bbd40696f32
Author: Nikhil Sheoran <12...@users.noreply.github.com>
AuthorDate: Tue Jan 23 17:15:30 2024 +0800

    [SPARK-46763] Fix assertion failure in ReplaceDeduplicateWithAggregate for duplicate attributes
    
    ### What changes were proposed in this pull request?
    
    - Updated the `ReplaceDeduplicateWithAggregate` implementation to reuse the alias already generated for an attribute, so that a duplicated attribute maps to a single alias.
    - Added a unit test to ensure scenarios with duplicate non-grouping keys are correctly optimized.
    
    ### Why are the changes needed?
    
    - `ReplaceDeduplicateWithAggregate` replaces `Deduplicate` with an `Aggregate` operator with grouping expressions for the deduplication keys and aggregate expressions for the non-grouping keys (to preserve the output schema and keep the non-grouping columns).
    - For a non-grouping key `a#X`, it generates an aggregate expression of the form `first(a#X, false) AS a#Y`.
    - If the non-grouping keys contain a repeated attribute (same name and exprId), the existing logic generates two separate aggregate expressions for it, each with a different exprId.
    - This produces two conflicting rewrite mappings for the same attribute, which triggers a "duplicate rewrite attributes" assertion failure in `transformUpWithNewOutput` when the rest of the tree is transformed.
    
    - For example, for the query
    ```
    Project [a#0, b#1]
    +- Deduplicate [b#1]
       +- Project [a#0, a#0, b#1]
          +- LocalRelation <empty>, [a#0, b#1]
    ```
    the existing logic would transform it to
    ```
    Project [a#3, b#1]
    +- Aggregate [b#1], [first(a#0, false) AS a#3, first(a#0, false) AS a#5, b#1]
       +- Project [a#0, a#0, b#1]
          +- LocalRelation <empty>, [a#0, b#1]
    ```
    with the aggregate mapping having two conflicting entries for the same attribute: `a#0 -> a#3` and `a#0 -> a#5`.
    
    The correct transformation would be
    ```
    Project [a#3, b#1]
    +- Aggregate [b#1], [first(a#0, false) AS a#3, first(a#0, false) AS a#3, b#1]
       +- Project [a#0, a#0, b#1]
          +- LocalRelation <empty>, [a#0, b#1]
    ```
    with the aggregate mapping having only one entry `a#0 -> a#3`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Added a unit test in `ReplaceOperatorSuite`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #44835 from nikhilsheoran-db/SPARK-46763.
    
    Authored-by: Nikhil Sheoran <12...@users.noreply.github.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 715b43428913d6a631f8f9043baac751b88cb5d4)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  6 ++++-
 .../catalyst/optimizer/ReplaceOperatorSuite.scala  | 31 ++++++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index df17840d567e..04d3eb962ed4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -2195,11 +2195,15 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
     case d @ Deduplicate(keys, child) if !child.isStreaming =>
       val keyExprIds = keys.map(_.exprId)
+      val generatedAliasesMap = new mutable.HashMap[Attribute, Alias]();
       val aggCols = child.output.map { attr =>
         if (keyExprIds.contains(attr.exprId)) {
           attr
         } else {
-          Alias(new First(attr).toAggregateExpression(), attr.name)()
+          // Keep track of the generated aliases to avoid generating multiple aliases
+          // for the same attribute (in case the attribute is duplicated)
+          generatedAliasesMap.getOrElseUpdate(attr,
+            Alias(new First(attr).toAggregateExpression(), attr.name)())
         }
       }
       // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index 5d81e96a8e58..cb9577e050d0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -265,4 +265,35 @@ class ReplaceOperatorSuite extends PlanTest {
       Join(basePlan, otherPlan, LeftAnti, Option(condition), JoinHint.NONE)).analyze
     comparePlans(result, correctAnswer)
   }
+
+  test("SPARK-46763: ReplaceDeduplicateWithAggregate non-grouping keys with duplicate attributes") {
+    val a = $"a".int
+    val b = $"b".int
+    val first_a = Alias(new First(a).toAggregateExpression(), a.name)()
+
+    val query = Project(
+      projectList = Seq(a, b),
+      Deduplicate(
+        keys = Seq(b),
+        child = Project(
+          projectList = Seq(a, a, b),
+          child = LocalRelation(Seq(a, b))
+        )
+      )
+    ).analyze
+
+    val result = Optimize.execute(query)
+    val correctAnswer = Project(
+        projectList = Seq(first_a.toAttribute, b),
+        Aggregate(
+            Seq(b),
+            Seq(first_a, first_a, b),
+            Project(
+              projectList = Seq(a, a, b),
+              child = LocalRelation(Seq(a, b))
+            )
+        )
+    ).analyze
+    comparePlans(result, correctAnswer)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org