Posted to commits@spark.apache.org by we...@apache.org on 2022/03/09 01:41:37 UTC

[spark] branch branch-3.1 updated: [SPARK-37865][SQL] Fix union deduplication correctness bug

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 0496173  [SPARK-37865][SQL] Fix union deduplication correctness bug
0496173 is described below

commit 0496173f2597e653a71694fd3b0f24879b62b5d8
Author: Karen Feng <ka...@databricks.com>
AuthorDate: Wed Mar 9 09:34:01 2022 +0800

    [SPARK-37865][SQL] Fix union deduplication correctness bug
    
    Fixes a correctness bug in `Union` when a child has duplicate output columns. Previously, duplicate columns on one side of the union caused the corresponding column on the other side of the union to be output as a duplicate as well, silently replacing its values.
    
    To fix this, we go through each child's output and find the duplicates. Within each set of duplicates, the first occurrence is left alone, while every later duplicate is aliased and given a tag; this tag is used to remove ambiguity during resolution.
    
    Because the first occurrence is left alone, the user can still select it, avoiding a breaking change. Because the later duplicates are given new expression IDs, the correctness bug is fixed.
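    
    A minimal sketch of the rewrite, using simplified stand-in types rather than Catalyst's `Attribute`/`Alias` (the names below are illustrative only):
    
    ```scala
    import scala.collection.mutable
    
    // Stand-in for an output attribute: an expression ID, a name, and a
    // marker for the __is_duplicate tag (a hypothetical simplification).
    case class Col(id: Int, name: String, tagged: Boolean = false)
    
    // Keep the first occurrence of each ID untouched; give every later
    // duplicate a fresh ID and the tag, as the patch does per Union child.
    def dedupOutput(output: Seq[Col], freshId: () => Int): Seq[Col] = {
      val seen = mutable.HashSet[Int]()
      output.map { col =>
        if (seen.contains(col.id)) {
          Col(freshId(), col.name, tagged = true) // aliased duplicate
        } else {
          seen.add(col.id) // first occurrence is left alone
          col
        }
      }
    }
    ```
    
    With input `Seq(Col(1, "a"), Col(1, "a"))`, the second `a` receives a fresh ID and the tag, so it can no longer be conflated with the first.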
    
    The output of a union with duplicate columns in its children was incorrect.
    
    Example query:
    ```
    SELECT a, a FROM VALUES (1, 1), (1, 2) AS t1(a, b)
    UNION ALL SELECT c, d FROM VALUES (2, 2), (2, 3) AS t2(c, d)
    ```
    
    Result before:
    ```
    a | a
    _ | _
    1 | 1
    1 | 1
    2 | 2
    2 | 2
    ```
    
    Result after:
    ```
    a | a
    _ | _
    1 | 1
    1 | 1
    2 | 2
    2 | 3
    ```
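    
    The same scenario can be reproduced through the DataFrame API; this mirrors the new test added below and assumes an active `SparkSession` with `spark.implicits._` imported:
    
    ```scala
    import spark.implicits._
    
    val df1 = Seq((1, 1), (1, 2)).toDF("a", "b")
    val df2 = Seq((2, 2), (2, 3)).toDF("c", "d")
    
    // Before the fix the last row came back as (2, 2); after the fix the
    // union correctly returns (1, 1), (1, 1), (2, 2), (2, 3).
    df1.selectExpr("a", "a").union(df2.selectExpr("c", "d")).show()
    ```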
    
    Unit tests
    
    Closes #35760 from karenfeng/spark-37865.
    
    Authored-by: Karen Feng <ka...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 59ce0a706cb52a54244a747d0a070b61f5cddd1c)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 25 +++++++++
 .../spark/sql/catalyst/expressions/package.scala   |  8 ++-
 .../org/apache/spark/sql/DataFrameSuite.scala      | 63 ++++++++++++++++++++++
 3 files changed, 95 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 5578838..ae5e2d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1602,6 +1602,31 @@ class Analyzer(override val catalogManager: CatalogManager)
         }
         u.copy(children = newChildren)
 
+      case u @ Union(children, _, _)
+        // if there are duplicate output columns, give them unique expr ids
+          if children.exists(c => c.output.map(_.exprId).distinct.length < c.output.length) =>
+        val newChildren = children.map { c =>
+          if (c.output.map(_.exprId).distinct.length < c.output.length) {
+            val existingExprIds = mutable.HashSet[ExprId]()
+            val projectList = c.output.map { attr =>
+              if (existingExprIds.contains(attr.exprId)) {
+                // replace non-first duplicates with aliases and tag them
+                val newMetadata = new MetadataBuilder().withMetadata(attr.metadata)
+                  .putNull("__is_duplicate").build()
+                Alias(attr, attr.name)(explicitMetadata = Some(newMetadata))
+              } else {
+                // leave first duplicate alone
+                existingExprIds.add(attr.exprId)
+                attr
+              }
+            }
+            Project(projectList, c)
+          } else {
+            c
+          }
+        }
+        u.withNewChildren(newChildren)
+
       // When resolve `SortOrder`s in Sort based on child, don't report errors as
       // we still have chance to resolve it based on its descendants
       case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
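
For the example query in the commit message, the new rule rewrites the first `Union` child roughly as follows (plan shape only; the expression IDs are illustrative):

```
Union
:- Project [a#1, a#1 AS a#3]   <- second `a` aliased, tagged __is_duplicate
:  +- LocalRelation [a#1, b#2]
+- LocalRelation [c#4, d#5]
```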
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index d950fef..6a4fb09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -335,8 +335,14 @@ package object expressions  {
         matchWithFourOrMoreQualifierParts(nameParts, resolver)
       }
 
+      val prunedCandidates = if (candidates.size > 1) {
+        candidates.filter(c => !c.metadata.contains("__is_duplicate"))
+      } else {
+        candidates
+      }
+
       def name = UnresolvedAttribute(nameParts).name
-      candidates match {
+      prunedCandidates match {
         case Seq(a) if nestedFields.nonEmpty =>
           // One match, but we also need to extract the requested nested field.
           // The foldLeft adds ExtractValues for every remaining parts of the identifier,
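
The pruning above makes an otherwise-ambiguous reference resolve to the surviving original column. A minimal stand-alone sketch of the rule, with a stand-in type rather than Catalyst's `Attribute`:

```scala
// Stand-in candidate: a name plus whether it carries the __is_duplicate tag.
case class Candidate(name: String, tagged: Boolean)

// With multiple matches, drop the tagged duplicates so that e.g.
// `SELECT a FROM (... UNION ...)` resolves to the first `a` instead of
// failing as ambiguous; a single match is never pruned.
def prune(candidates: Seq[Candidate]): Seq[Candidate] =
  if (candidates.size > 1) candidates.filterNot(_.tagged) else candidates
```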
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index c242bc8..8730aeb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2697,6 +2697,69 @@ class DataFrameSuite extends QueryTest
       checkAnswer(sql("SELECT sum(c1 * c3) + sum(c2 * c3) FROM tbl"), Row(2.00000000000) :: Nil)
     }
   }
+
+  test("SPARK-37865: Do not deduplicate union output columns") {
+    val df1 = Seq((1, 1), (1, 2)).toDF("a", "b")
+    val df2 = Seq((2, 2), (2, 3)).toDF("c", "d")
+
+    def sqlQuery(cols1: Seq[String], cols2: Seq[String], distinct: Boolean): String = {
+      val union = if (distinct) {
+        "UNION"
+      } else {
+        "UNION ALL"
+      }
+      s"""
+         |SELECT ${cols1.mkString(",")} FROM VALUES (1, 1), (1, 2) AS t1(a, b)
+         |$union SELECT ${cols2.mkString(",")} FROM VALUES (2, 2), (2, 3) AS t2(c, d)
+         |""".stripMargin
+    }
+
+    Seq(
+      (Seq("a", "a"), Seq("c", "d"), Seq(Row(1, 1), Row(1, 1), Row(2, 2), Row(2, 3))),
+      (Seq("a", "b"), Seq("c", "d"), Seq(Row(1, 1), Row(1, 2), Row(2, 2), Row(2, 3))),
+      (Seq("a", "b"), Seq("c", "c"), Seq(Row(1, 1), Row(1, 2), Row(2, 2), Row(2, 2)))
+    ).foreach { case (cols1, cols2, rows) =>
+      // UNION ALL (non-distinct)
+      val df3 = df1.selectExpr(cols1: _*).union(df2.selectExpr(cols2: _*))
+      checkAnswer(df3, rows)
+
+      val t3 = sqlQuery(cols1, cols2, false)
+      checkAnswer(sql(t3), rows)
+
+      // Avoid breaking change
+      var correctAnswer = rows.map(r => Row(r(0)))
+      checkAnswer(df3.select(df1.col("a")), correctAnswer)
+      checkAnswer(sql(s"select a from ($t3) t3"), correctAnswer)
+
+      // This has always been broken
+      intercept[AnalysisException] {
+        df3.select(df2.col("d")).collect()
+      }
+      intercept[AnalysisException] {
+        sql(s"select d from ($t3) t3")
+      }
+
+      // UNION (distinct)
+      val df4 = df3.distinct
+      checkAnswer(df4, rows.distinct)
+
+      val t4 = sqlQuery(cols1, cols2, true)
+      checkAnswer(sql(t4), rows.distinct)
+
+      // Avoid breaking change
+      correctAnswer = rows.distinct.map(r => Row(r(0)))
+      checkAnswer(df4.select(df1.col("a")), correctAnswer)
+      checkAnswer(sql(s"select a from ($t4) t4"), correctAnswer)
+
+      // This has always been broken
+      intercept[AnalysisException] {
+        df4.select(df2.col("d")).collect()
+      }
+      intercept[AnalysisException] {
+        sql(s"select d from ($t4) t4")
+      }
+    }
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)
