Posted to commits@spark.apache.org by do...@apache.org on 2020/09/18 17:33:40 UTC

[spark] branch branch-2.4 updated: [SPARK-32635][SQL][2.4] Fix foldable propagation

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 62708db  [SPARK-32635][SQL][2.4] Fix foldable propagation
62708db is described below

commit 62708db4f90f652cd9bc73998ac5f1e949bd41ac
Author: Peter Toth <pe...@gmail.com>
AuthorDate: Fri Sep 18 10:28:30 2020 -0700

    [SPARK-32635][SQL][2.4] Fix foldable propagation
    
    ### What changes were proposed in this pull request?
    This PR rewrites `FoldablePropagation` rule to replace attribute references in a node with foldables coming only from the node's children.
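    
    As an illustration only (not part of this change), here is a minimal sketch using the Catalyst test DSL, the same DSL used in `FoldablePropagationSuite`, of the case that remains safe; the relation and column names are made up for illustration:
    ```scala
    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    
    // The child Project produces the foldable alias `1 AS x` and the parent Filter
    // references `x`, so inlining the literal 1 into the Filter condition is safe.
    // The bug fixed here was that the old rule could also inline foldables that did
    // not come from the referencing node's children.
    val query = LocalRelation('a.int)
      .select(Literal(1).as('x))
      .where('x === 1)
      .analyze
    ```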
    
    Before this PR, in the case of this example (with `spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation` set):
    ```scala
    val a = Seq("1").toDF("col1").withColumn("col2", lit("1"))
    val b = Seq("2").toDF("col1").withColumn("col2", lit("2"))
    val aub = a.union(b)
    val c = aub.filter($"col1" === "2").cache()
    val d = Seq("2").toDF("col4")
    val r = d.join(aub, $"col2" === $"col4").select("col4")
    val l = c.select("col2")
    val df = l.join(r, $"col2" === $"col4", "LeftOuter")
    df.show()
    ```
    foldable propagation happens incorrectly (note how `Project [col2#6]` is rewritten to `Project [1 AS col2#6]` in the plan on the right):
    ```
     Join LeftOuter, (col2#6 = col4#34)                                                              Join LeftOuter, (col2#6 = col4#34)
    !:- Project [col2#6]                                                                             :- Project [1 AS col2#6]
     :  +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas)   :  +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas)
     :        +- Union                                                                               :        +- Union
     :           :- *(1) Project [value#1 AS col1#4, 1 AS col2#6]                                    :           :- *(1) Project [value#1 AS col1#4, 1 AS col2#6]
     :           :  +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2))                            :           :  +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2))
     :           :     +- *(1) LocalTableScan [value#1]                                              :           :     +- *(1) LocalTableScan [value#1]
     :           +- *(2) Project [value#10 AS col1#13, 2 AS col2#15]                                 :           +- *(2) Project [value#10 AS col1#13, 2 AS col2#15]
     :              +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2))                          :              +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2))
     :                 +- *(2) LocalTableScan [value#10]                                             :                 +- *(2) LocalTableScan [value#10]
     +- Project [col4#34]                                                                            +- Project [col4#34]
        +- Join Inner, (col2#6 = col4#34)                                                               +- Join Inner, (col2#6 = col4#34)
           :- Project [value#31 AS col4#34]                                                                :- Project [value#31 AS col4#34]
           :  +- LocalRelation [value#31]                                                                  :  +- LocalRelation [value#31]
           +- Project [col2#6]                                                                             +- Project [col2#6]
              +- Union false, false                                                                           +- Union false, false
                 :- Project [1 AS col2#6]                                                                        :- Project [1 AS col2#6]
                 :  +- LocalRelation [value#1]                                                                   :  +- LocalRelation [value#1]
                 +- Project [2 AS col2#15]                                                                       +- Project [2 AS col2#15]
                    +- LocalRelation [value#10]                                                                     +- LocalRelation [value#10]
    
    ```
    and so the result is wrong:
    ```
    +----+----+
    |col2|col4|
    +----+----+
    |   1|null|
    +----+----+
    ```
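    
    For reference (not something added by this patch), the mis-rewritten plan can also be inspected directly from the reproduction above:
    ```scala
    // Prints the parsed, analyzed, optimized and physical plans; the optimized
    // logical plan contains the incorrect `Project [1 AS col2#6]` shown above.
    df.explain(true)
    
    // Or print only the optimized logical plan:
    println(df.queryExecution.optimizedPlan.treeString)
    ```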
    
    After this PR, foldable propagation no longer happens incorrectly, and the result is correct:
    ```
    +----+----+
    |col2|col4|
    +----+----+
    |   2|   2|
    +----+----+
    ```
    
    ### Why are the changes needed?
    To fix a correctness issue.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, fixes a correctness issue.
    
    ### How was this patch tested?
    Existing and new UTs.
    
    Closes #29805 from peter-toth/SPARK-32635-fix-foldable-propagation-2.4.
    
    Authored-by: Peter Toth <pe...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/catalyst/expressions/AttributeMap.scala    |   2 +
 .../spark/sql/catalyst/optimizer/expressions.scala | 121 ++++++++++++---------
 .../optimizer/FoldablePropagationSuite.scala       |  12 ++
 .../org/apache/spark/sql/DataFrameSuite.scala      |  12 ++
 4 files changed, 98 insertions(+), 49 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 9f4a0f2..1e8f8ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -26,6 +26,8 @@ object AttributeMap {
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
+
+  def empty[A]: AttributeMap[A] = new AttributeMap(Map.empty)
 }
 
 class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index be0e702..6c6f6313 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -595,59 +595,82 @@ object NullPropagation extends Rule[LogicalPlan] {
  */
 object FoldablePropagation extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
-    var foldableMap = AttributeMap(plan.flatMap {
-      case Project(projectList, _) => projectList.collect {
-        case a: Alias if a.child.foldable => (a.toAttribute, a)
-      }
-      case _ => Nil
-    })
-    val replaceFoldable: PartialFunction[Expression, Expression] = {
-      case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+    CleanupAliases(propagateFoldables(plan)._1)
+  }
+
+  private def propagateFoldables(plan: LogicalPlan): (LogicalPlan, AttributeMap[Alias]) = {
+    plan match {
+      case p: Project =>
+        val (newChild, foldableMap) = propagateFoldables(p.child)
+        val newProject =
+          replaceFoldable(p.withNewChildren(Seq(newChild)).asInstanceOf[Project], foldableMap)
+        val newFoldableMap = AttributeMap(newProject.projectList.collect {
+          case a: Alias if a.child.foldable => (a.toAttribute, a)
+        })
+        (newProject, newFoldableMap)
+
+      // We can not replace the attributes in `Expand.output`. If there are other non-leaf
+      // operators that have the `output` field, we should put them here too.
+      case e: Expand =>
+        val (newChild, foldableMap) = propagateFoldables(e.child)
+        val expandWithNewChildren = e.withNewChildren(Seq(newChild)).asInstanceOf[Expand]
+        val newExpand = if (foldableMap.isEmpty) {
+          expandWithNewChildren
+        } else {
+          val newProjections = expandWithNewChildren.projections.map(_.map(_.transform {
+            case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+          }))
+          if (newProjections == expandWithNewChildren.projections) {
+            expandWithNewChildren
+          } else {
+            expandWithNewChildren.copy(projections = newProjections)
+          }
+        }
+        (newExpand, foldableMap)
+
+      case u: UnaryNode if canPropagateFoldables(u) =>
+        val (newChild, foldableMap) = propagateFoldables(u.child)
+        val newU = replaceFoldable(u.withNewChildren(Seq(newChild)), foldableMap)
+        (newU, foldableMap)
+
+      // Join derives the output attributes from its child while they are actually not the
+      // same attributes. For example, the output of outer join is not always picked from its
+      // children, but can also be null. We should exclude these miss-derived attributes when
+      // propagating the foldable expressions.
+      // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
+      // of outer join.
+      case j: Join =>
+        val (newChildren, foldableMaps) = j.children.map(propagateFoldables).unzip
+        val foldableMap = AttributeMap(
+          foldableMaps.foldLeft(Iterable.empty[(Attribute, Alias)])(_ ++ _.baseMap.values).toSeq)
+        val newJoin =
+          replaceFoldable(j.withNewChildren(newChildren).asInstanceOf[Join], foldableMap)
+        val missDerivedAttrsSet: AttributeSet = AttributeSet(newJoin.joinType match {
+          case _: InnerLike | LeftExistence(_) => Nil
+          case LeftOuter => newJoin.right.output
+          case RightOuter => newJoin.left.output
+          case FullOuter => newJoin.left.output ++ newJoin.right.output
+        })
+        val newFoldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
+          case (attr, _) => missDerivedAttrsSet.contains(attr)
+        }.toSeq)
+        (newJoin, newFoldableMap)
+
+      // For other plans, they are not safe to apply foldable propagation, and they should not
+      // propagate foldable expressions from children.
+      case o =>
+        val newOther = o.mapChildren(propagateFoldables(_)._1)
+        (newOther, AttributeMap.empty)
     }
+  }
 
+  private def replaceFoldable(plan: LogicalPlan, foldableMap: AttributeMap[Alias]): plan.type = {
     if (foldableMap.isEmpty) {
       plan
     } else {
-      CleanupAliases(plan.transformUp {
-        // We can only propagate foldables for a subset of unary nodes.
-        case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) =>
-          u.transformExpressions(replaceFoldable)
-
-        // Join derives the output attributes from its child while they are actually not the
-        // same attributes. For example, the output of outer join is not always picked from its
-        // children, but can also be null. We should exclude these miss-derived attributes when
-        // propagating the foldable expressions.
-        // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
-        // of outer join.
-        case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty =>
-          val newJoin = j.transformExpressions(replaceFoldable)
-          val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match {
-            case _: InnerLike | LeftExistence(_) => Nil
-            case LeftOuter => right.output
-            case RightOuter => left.output
-            case FullOuter => left.output ++ right.output
-          })
-          foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
-            case (attr, _) => missDerivedAttrsSet.contains(attr)
-          }.toSeq)
-          newJoin
-
-        // We can not replace the attributes in `Expand.output`. If there are other non-leaf
-        // operators that have the `output` field, we should put them here too.
-        case expand: Expand if foldableMap.nonEmpty =>
-          expand.copy(projections = expand.projections.map { projection =>
-            projection.map(_.transform(replaceFoldable))
-          })
-
-        // For other plans, they are not safe to apply foldable propagation, and they should not
-        // propagate foldable expressions from children.
-        case other if foldableMap.nonEmpty =>
-          val childrenOutputSet = AttributeSet(other.children.flatMap(_.output))
-          foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
-            case (attr, _) => childrenOutputSet.contains(attr)
-          }.toSeq)
-          other
-      })
+      plan transformExpressions {
+        case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+      }
     }
   }
 
@@ -655,7 +678,7 @@ object FoldablePropagation extends Rule[LogicalPlan] {
    * Whitelist of all [[UnaryNode]]s for which allow foldable propagation.
    */
   private def canPropagateFoldables(u: UnaryNode): Boolean = u match {
-    case _: Project => true
+    // Handling `Project` is moved to `propagateFoldables`.
     case _: Filter => true
     case _: SubqueryAlias => true
     case _: Aggregate => true
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala
index c288446..2c45199 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala
@@ -180,4 +180,16 @@ class FoldablePropagationSuite extends PlanTest {
       .select((Literal(1) + 3).as('res)).analyze
     comparePlans(optimized, correctAnswer)
   }
+
+  test("SPARK-32635: Replace references with foldables coming only from the node's children") {
+    val leftExpression = 'a.int
+    val left = LocalRelation(leftExpression).select('a)
+    val rightExpression = Alias(Literal(2), "a")(leftExpression.exprId)
+    val right = LocalRelation('b.int).select('b, rightExpression).select('b)
+    val join = left.join(right, joinType = LeftOuter, condition = Some('b === 'a))
+
+    val query = join.analyze
+    val optimized = Optimize.execute(query)
+    comparePlans(optimized, query)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index e7d55ee..037cf23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2720,6 +2720,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val nestedDecArray = Array(decSpark)
     checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava))))
   }
+
+  test("SPARK-32635: Replace references with foldables coming only from the node's children") {
+    val a = Seq("1").toDF("col1").withColumn("col2", lit("1"))
+    val b = Seq("2").toDF("col1").withColumn("col2", lit("2"))
+    val aub = a.union(b)
+    val c = aub.filter($"col1" === "2").cache()
+    val d = Seq("2").toDF("col4")
+    val r = d.join(aub, $"col2" === $"col4").select("col4")
+    val l = c.select("col2")
+    val df = l.join(r, $"col2" === $"col4", "LeftOuter")
+    checkAnswer(df, Row("2", "2"))
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org