You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2020/07/23 14:30:52 UTC

[spark] branch branch-3.0 updated: [SPARK-32280][SPARK-32372][SQL] ResolveReferences.dedupRight should only rewrite attributes for ancestor nodes of the conflict plan

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new ebac47b  [SPARK-32280][SPARK-32372][SQL] ResolveReferences.dedupRight should only rewrite attributes for ancestor nodes of the conflict plan
ebac47b is described below

commit ebac47b96cb38d7cd5d73a3f238017a5f8d77d1a
Author: yi.wu <yi...@databricks.com>
AuthorDate: Thu Jul 23 14:24:47 2020 +0000

    [SPARK-32280][SPARK-32372][SQL] ResolveReferences.dedupRight should only rewrite attributes for ancestor nodes of the conflict plan
    
    This PR refactors `ResolveReferences.dedupRight` to make sure it only rewrites attributes for ancestor nodes of the conflict plan.
    
    This is a bug fix.
    
    ```scala
    sql("SELECT name, avg(age) as avg_age FROM person GROUP BY name")
      .createOrReplaceTempView("person_a")
    sql("SELECT p1.name, p2.avg_age FROM person p1 JOIN person_a p2 ON p1.name = p2.name")
      .createOrReplaceTempView("person_b")
    sql("SELECT * FROM person_a UNION SELECT * FROM person_b")
      .createOrReplaceTempView("person_c")
    sql("SELECT p1.name, p2.avg_age FROM person_c p1 JOIN person_c p2 ON p1.name = p2.name").show()
    ```
    When executing the above query, we'll hit the error:
    
    ```scala
    [info]   Failed to analyze query: org.apache.spark.sql.AnalysisException: Resolved attribute(s) avg_age#231 missing from name#223,avg_age#218,id#232,age#234,name#233 in operator !Project [name#233, avg_age#231]. Attribute(s) with the same name appear in the operation: avg_age. Please check if the right attribute(s) are used.;;
    ...
    ```
    
    The plan below is the problematic plan: it is the right child of a `Join` operator, and it contains plans that conflict with the left child. In this plan, the first `Aggregate` operator (the one under the first child of `Union`) becomes a conflict plan compared to the left one, with the rewrite attribute pair `avg_age#218` -> `avg_age#231`. With the current `dedupRight` logic, we first replace this `Aggregate` with a new one, and then rewrite the attribute `avg_age# [...]
    
    ```scala
    :

    :
    +- SubqueryAlias p2
       +- SubqueryAlias person_c
          +- Distinct
             +- Union
                :- Project [name#233, avg_age#231]
                :  +- SubqueryAlias person_a
                :     +- Aggregate [name#233], [name#233, avg(cast(age#234 as bigint)) AS avg_age#231]
                :        +- SubqueryAlias person
                :           +- SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#232, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS name#233, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#234]
                :              +- ExternalRDD [obj#165]
                +- Project [name#233 AS name#227, avg_age#231 AS avg_age#228]
                   +- Project [name#233, avg_age#231]
                      +- SubqueryAlias person_b
                         +- !Project [name#233, avg_age#231]
                            +- Join Inner, (name#233 = name#223)
                               :- SubqueryAlias p1
                               :  +- SubqueryAlias person
                               :     +- SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#232, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS name#233, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#234]
                               :        +- ExternalRDD [obj#165]
                               +- SubqueryAlias p2
                                  +- SubqueryAlias person_a
                                     +- Aggregate [name#223], [name#223, avg(cast(age#224 as bigint)) AS avg_age#218]
                                        +- SubqueryAlias person
                                           +- SerializeFromObject [knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).id AS id#222, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).name, true, false) AS name#223, knownnotnull(assertnotnull(input[0, org.apache.spark.sql.test.SQLTestData$Person, true])).age AS age#224]
                                              +- ExternalRDD [obj#165]
    ```
    
    User-facing change: yes. Users will no longer hit the error above after this fix.
    
    Testing: added two tests to SQLQuerySuite (SPARK-32372 and SPARK-32280).
    
    Closes #29166 from Ngone51/impr-dedup.
    
    Authored-by: yi.wu <yi...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit a8e3de36e7d543f1c7923886628ac3178f45f512)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 63 ++++++++++++++++++----
 .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 22 ++++++++
 2 files changed, 75 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 68fe580..bd5a797 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1181,11 +1181,24 @@ class Analyzer(
             if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
           Seq((oldVersion, oldVersion.copy(projectList = newAliases(projectList))))
 
+        // We don't need to search child plan recursively if the projectList of a Project
+        // is only composed of Alias and doesn't contain any conflicting attributes.
+        // Because, even if the child plan has some conflicting attributes, the attributes
+        // will be aliased to non-conflicting attributes by the Project at the end.
+        case _ @ Project(projectList, _)
+          if findAliases(projectList).size == projectList.size =>
+          Nil
+
         case oldVersion @ Aggregate(_, aggregateExpressions, _)
             if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
           Seq((oldVersion, oldVersion.copy(
             aggregateExpressions = newAliases(aggregateExpressions))))
 
+        // We don't search the child plan recursively for the same reason as the above Project.
+        case _ @ Aggregate(_, aggregateExpressions, _)
+          if findAliases(aggregateExpressions).size == aggregateExpressions.size =>
+          Nil
+
         case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
             if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
           Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
@@ -1226,20 +1239,50 @@ class Analyzer(
       if (conflictPlans.isEmpty) {
         right
       } else {
-        val attributeRewrites = AttributeMap(conflictPlans.flatMap {
-          case (oldRelation, newRelation) => oldRelation.output.zip(newRelation.output)})
-        val conflictPlanMap = conflictPlans.toMap
-        // transformDown so that we can replace all the old Relations in one turn due to
-        // the reason that `conflictPlans` are also collected in pre-order.
-        right transformDown {
-          case r => conflictPlanMap.getOrElse(r, r)
-        } transformUp {
-          case other => other transformExpressions {
+        rewritePlan(right, conflictPlans.toMap)._1
+      }
+    }
+
+    private def rewritePlan(plan: LogicalPlan, conflictPlanMap: Map[LogicalPlan, LogicalPlan])
+      : (LogicalPlan, Seq[(Attribute, Attribute)]) = { // (rewritten plan, old -> new attrs)
+      if (conflictPlanMap.contains(plan)) {
+        // This plan conflicts with one on the left side: replace it with
+        // its deduplicated copy and return the old -> new output attribute
+        // pairs so that ancestor nodes can rewrite their references.
+        val newRelation = conflictPlanMap(plan)
+        newRelation -> plan.output.zip(newRelation.output)
+      } else {
+        val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
+        val newPlan = plan.mapChildren { child =>
+          // Not a conflict plan: recurse into each child until a conflict
+          // plan is found or a leaf node is reached.
+          val (newChild, childAttrMapping) = rewritePlan(child, conflictPlanMap)
+          attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
+            // `attrMapping` is not only used to replace the attributes of the current `plan`,
+            // but also to be propagated to the parent plans of the current `plan`. Therefore,
+            // the `oldAttr` must be part of either `plan.references` (so that it can be used to
+            // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
+            // used by those parent plans).
+            (plan.outputSet ++ plan.references).contains(oldAttr)
+          }
+          newChild
+        }
+
+        if (attrMapping.isEmpty) {
+          newPlan -> attrMapping
+        } else {
+          assert(!attrMapping.groupBy(_._1.exprId) // each old exprId maps to one new exprId
+            .exists(_._2.map(_._2.exprId).distinct.length > 1),
+            "Found duplicate rewrite attributes")
+          val attributeRewrites = AttributeMap(attrMapping)
+          // Rewrite this node using only the mappings collected from its own children;
+          // a node must never be rewritten with mappings that came from its siblings.
+          newPlan.transformExpressions {
             case a: Attribute =>
               dedupAttr(a, attributeRewrites)
             case s: SubqueryExpression =>
               s.withNewPlan(dedupOuterReferencesInSubquery(s.plan, attributeRewrites))
-          }
+          } -> attrMapping
         }
       }
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 093f2db..6fab47d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -3467,6 +3467,28 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
             |""".stripMargin), Row(1))
     }
   }
+
+  test("SPARK-32372: ResolveReferences.dedupRight should only rewrite attributes for ancestor " +
+    "plans of the conflict plan") {
+    sql("SELECT name, avg(age) as avg_age FROM person GROUP BY name") // aggregate per name
+      .createOrReplaceTempView("person_a")
+    sql("SELECT p1.name, p2.avg_age FROM person p1 JOIN person_a p2 ON p1.name = p2.name")
+      .createOrReplaceTempView("person_b")
+    sql("SELECT * FROM person_a UNION SELECT * FROM person_b") // person_a in both branches
+      .createOrReplaceTempView("person_c")
+    checkAnswer( // self-join: the right side conflicts with the left side
+      sql("SELECT p1.name, p2.avg_age FROM person_c p1 JOIN person_c p2 ON p1.name = p2.name"),
+      Row("jim", 20.0) :: Row("mike", 30.0) :: Nil)
+  }
+
+  test("SPARK-32280: Avoid duplicate rewrite attributes when there're multiple JOINs") {
+    sql("SELECT 1 AS id").createOrReplaceTempView("A") // one-row base view
+    sql("SELECT id, 'foo' AS kind FROM A").createOrReplaceTempView("B")
+    sql("SELECT l.id as id FROM B AS l LEFT SEMI JOIN B AS r ON l.kind = r.kind") // B self-join
+      .createOrReplaceTempView("C")
+    checkAnswer(sql("SELECT 0 FROM ( SELECT * FROM B JOIN C USING (id)) " + // repeated subquery
+      "JOIN ( SELECT * FROM B JOIN C USING (id)) USING (id)"), Row(0))
+  }
 }
 
 case class Foo(bar: Option[String])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org