You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2019/11/14 06:48:40 UTC

[spark] branch branch-2.4 updated: [SPARK-29682][SQL] Resolve conflicting attributes in Expand correctly

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 7bdc76f  [SPARK-29682][SQL] Resolve conflicting attributes in Expand correctly
7bdc76f is described below

commit 7bdc76f9c1f062aecdaf361527adbd9f50518bc9
Author: Terry Kim <yu...@gmail.com>
AuthorDate: Thu Nov 14 14:47:14 2019 +0800

    [SPARK-29682][SQL] Resolve conflicting attributes in Expand correctly
    
    ### What changes were proposed in this pull request?
    
    This PR addresses issues where conflicting attributes in `Expand` are not correctly handled.
    
    ### Why are the changes needed?
    
    ```Scala
    val numsDF = Seq(1, 2, 3, 4, 5, 6).toDF("nums")
    val cubeDF = numsDF.cube("nums").agg(max(lit(0)).as("agcol"))
    cubeDF.join(cubeDF, "nums").show
    ```
    fails with the following exception:
    ```
    org.apache.spark.sql.AnalysisException:
    Failure when resolving conflicting references in Join:
    'Join Inner
    :- Aggregate [nums#38, spark_grouping_id#36], [nums#38, max(0) AS agcol#35]
    :  +- Expand [List(nums#3, nums#37, 0), List(nums#3, null, 1)], [nums#3, nums#38, spark_grouping_id#36]
    :     +- Project [nums#3, nums#3 AS nums#37]
    :        +- Project [value#1 AS nums#3]
    :           +- LocalRelation [value#1]
    +- Aggregate [nums#38, spark_grouping_id#36], [nums#38, max(0) AS agcol#58]
       +- Expand [List(nums#3, nums#37, 0), List(nums#3, null, 1)], [nums#3, nums#38, spark_grouping_id#36]
                                                                             ^^^^^^^
          +- Project [nums#3, nums#3 AS nums#37]
             +- Project [value#1 AS nums#3]
                +- LocalRelation [value#1]
    
    Conflicting attributes: nums#38
    ```
    As you can see from the above plan, `nums#38`, an output of `Expand` on the right side of the `Join`, should have been re-aliased to a new attribute. Because the conflict is not resolved at `Expand`, the failure surfaces at its parent, `Aggregate`. This PR adds handling for conflicting attributes produced by `Expand`.
    
    ### Does this PR introduce any user-facing change?
    
    Yes, the previous example now shows the following output:
    ```
    +----+-----+-----+
    |nums|agcol|agcol|
    +----+-----+-----+
    |   1|    0|    0|
    |   6|    0|    0|
    |   4|    0|    0|
    |   2|    0|    0|
    |   5|    0|    0|
    |   3|    0|    0|
    +----+-----+-----+
    ```
    ### How was this patch tested?
    
    Added a new unit test.
    
    Closes #26441 from imback82/spark-29682.
    
    Authored-by: Terry Kim <yu...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit e46e487b0831b39afa12ef9cff9b9133f111921b)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala    | 12 ++++++++++++
 .../sql/catalyst/plans/logical/basicLogicalOperators.scala   |  2 ++
 .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala  |  9 +++++++++
 3 files changed, 23 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index bed742f..f8b9513 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -792,6 +792,18 @@ class Analyzer(
           val newOutput = oldVersion.generatorOutput.map(_.newInstance())
           (oldVersion, oldVersion.copy(generatorOutput = newOutput))
 
+        case oldVersion: Expand
+            if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
+          val producedAttributes = oldVersion.producedAttributes
+          val newOutput = oldVersion.output.map { attr =>
+            if (producedAttributes.contains(attr)) {
+              attr.newInstance()
+            } else {
+              attr
+            }
+          }
+          (oldVersion, oldVersion.copy(output = newOutput))
+
         case oldVersion @ Window(windowExpressions, _, _, child)
             if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
               .nonEmpty =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 84b4a4d..fd6ebff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -708,6 +708,8 @@ case class Expand(
   override def references: AttributeSet =
     AttributeSet(projections.flatten.flatMap(_.references))
 
+  override def producedAttributes: AttributeSet = AttributeSet(output diff child.output)
+
   // This operator can reuse attributes (for example making them null when doing a roll up) so
   // the constraints of the child may no longer be valid.
   override protected def validConstraints: Set[Expression] = Set.empty[Expression]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 38fb5d8..3ad7334 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -3060,6 +3060,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         """.stripMargin).collect()
     }
   }
+
+  test("SPARK-29682: Conflicting attributes in Expand are resolved") {
+    val numsDF = Seq(1, 2, 3).toDF("nums")
+    val cubeDF = numsDF.cube("nums").agg(max(lit(0)).as("agcol"))
+
+    checkAnswer(
+      cubeDF.join(cubeDF, "nums"),
+      Row(1, 0, 0) :: Row(2, 0, 0) :: Row(3, 0, 0) :: Nil)
+  }
 }
 
 case class Foo(bar: Option[String])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org