Posted to commits@spark.apache.org by ya...@apache.org on 2020/09/17 23:20:36 UTC

[spark] branch master updated: [SPARK-32635][SQL] Fix foldable propagation

This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4ced588  [SPARK-32635][SQL] Fix foldable propagation
4ced588 is described below

commit 4ced58862c707aa916f7a55d15c3887c94c9b210
Author: Peter Toth <pe...@gmail.com>
AuthorDate: Fri Sep 18 08:17:23 2020 +0900

    [SPARK-32635][SQL] Fix foldable propagation
    
    ### What changes were proposed in this pull request?
    This PR rewrites the `FoldablePropagation` rule so that attribute references in a node are replaced only with foldable aliases coming from that node's children.
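    
    To illustrate the general idea (a minimal sketch, not code from this PR; the example and the `spark` session value, e.g. from spark-shell, are assumptions): when a child `Project` produces an alias over a foldable literal, the rule can substitute that literal into the parent's expressions.
    ```scala
    // Hypothetical illustration of foldable propagation (not part of this change).
    import spark.implicits._
    import org.apache.spark.sql.functions.lit

    // col2 is the foldable literal "1" produced by the Project under the filter,
    // so the optimizer may rewrite the condition (col2 = "1") into ("1" = "1").
    val df = Seq("a").toDF("col1")
      .withColumn("col2", lit("1"))
      .filter($"col2" === "1")
    df.explain(true) // the optimized plan shows the propagated literal
    ```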
    
    Before this PR, running the following example with `spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation` set (a sketch for applying this setting follows the snippet):
    ```scala
    val a = Seq("1").toDF("col1").withColumn("col2", lit("1"))
    val b = Seq("2").toDF("col1").withColumn("col2", lit("2"))
    val aub = a.union(b)
    val c = aub.filter($"col1" === "2").cache()
    val d = Seq("2").toDF("col4")
    val r = d.join(aub, $"col2" === $"col4").select("col4")
    val l = c.select("col2")
    val df = l.join(r, $"col2" === $"col4", "LeftOuter")
    df.show()
    ```
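    One way to apply the rule exclusion mentioned above before running the snippet (a sketch using the runtime config API; a `spark` session, e.g. from spark-shell, is assumed):
    ```scala
    // Exclude ConvertToLocalRelation, as described in the setting quoted above.
    spark.conf.set(
      "spark.sql.optimizer.excludedRules",
      "org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation")
    ```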
    With the rule exclusion in place, foldable propagation is applied incorrectly (the plan on the left is rewritten into the plan on the right):
    ```
     Join LeftOuter, (col2#6 = col4#34)                                                              Join LeftOuter, (col2#6 = col4#34)
    !:- Project [col2#6]                                                                             :- Project [1 AS col2#6]
     :  +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas)   :  +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas)
     :        +- Union                                                                               :        +- Union
     :           :- *(1) Project [value#1 AS col1#4, 1 AS col2#6]                                    :           :- *(1) Project [value#1 AS col1#4, 1 AS col2#6]
     :           :  +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2))                            :           :  +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2))
     :           :     +- *(1) LocalTableScan [value#1]                                              :           :     +- *(1) LocalTableScan [value#1]
     :           +- *(2) Project [value#10 AS col1#13, 2 AS col2#15]                                 :           +- *(2) Project [value#10 AS col1#13, 2 AS col2#15]
     :              +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2))                          :              +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2))
     :                 +- *(2) LocalTableScan [value#10]                                             :                 +- *(2) LocalTableScan [value#10]
     +- Project [col4#34]                                                                            +- Project [col4#34]
        +- Join Inner, (col2#6 = col4#34)                                                               +- Join Inner, (col2#6 = col4#34)
           :- Project [value#31 AS col4#34]                                                                :- Project [value#31 AS col4#34]
           :  +- LocalRelation [value#31]                                                                  :  +- LocalRelation [value#31]
           +- Project [col2#6]                                                                             +- Project [col2#6]
              +- Union false, false                                                                           +- Union false, false
                 :- Project [1 AS col2#6]                                                                        :- Project [1 AS col2#6]
                 :  +- LocalRelation [value#1]                                                                   :  +- LocalRelation [value#1]
                 +- Project [2 AS col2#15]                                                                       +- Project [2 AS col2#15]
                    +- LocalRelation [value#10]                                                                     +- LocalRelation [value#10]
    
    ```
    Here the left-side `Project [col2#6]` above the cached `InMemoryRelation` is rewritten to `Project [1 AS col2#6]` using a foldable alias that does not come from its child, so the result is wrong:
    ```
    +----+----+
    |col2|col4|
    +----+----+
    |   1|null|
    +----+----+
    ```
    
    After this PR, foldable propagation no longer performs this rewrite and the result is correct:
    ```
    +----+----+
    |col2|col4|
    +----+----+
    |   2|   2|
    +----+----+
    ```
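    The optimized plan can also be inspected directly to confirm the change (a minimal sketch; `df` is the DataFrame built in the snippet above):
    ```scala
    // After the fix, the left Project keeps col2 as an attribute reference instead
    // of the incorrectly propagated literal 1.
    df.explain(true)
    // Or print the optimized logical plan programmatically:
    println(df.queryExecution.optimizedPlan)
    ```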
    
    ### Why are the changes needed?
    To fix a correctness issue.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, fixes a correctness issue.
    
    ### How was this patch tested?
    Existing and new UTs.
    
    Closes #29771 from peter-toth/SPARK-32635-fix-foldable-propagation.
    
    Authored-by: Peter Toth <pe...@gmail.com>
    Signed-off-by: Takeshi Yamamuro <ya...@apache.org>
---
 .../sql/catalyst/expressions/AttributeMap.scala    |   2 +
 .../sql/catalyst/expressions/AttributeMap.scala    |   2 +
 .../spark/sql/catalyst/optimizer/expressions.scala | 121 ++++++++++++---------
 .../org/apache/spark/sql/DataFrameSuite.scala      |  12 ++
 4 files changed, 88 insertions(+), 49 deletions(-)

diff --git a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 75a8bec..42b92d4 100644
--- a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -26,6 +26,8 @@ object AttributeMap {
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
+
+  def empty[A]: AttributeMap[A] = new AttributeMap(Map.empty)
 }
 
 class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
diff --git a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 4caa3d0..e6b53e3 100644
--- a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -26,6 +26,8 @@ object AttributeMap {
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
+
+  def empty[A]: AttributeMap[A] = new AttributeMap(Map.empty)
 }
 
 class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index b2fc393..c4e4b25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -624,59 +624,82 @@ object NullPropagation extends Rule[LogicalPlan] {
  */
 object FoldablePropagation extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
-    var foldableMap = AttributeMap(plan.flatMap {
-      case Project(projectList, _) => projectList.collect {
-        case a: Alias if a.child.foldable => (a.toAttribute, a)
-      }
-      case _ => Nil
-    })
-    val replaceFoldable: PartialFunction[Expression, Expression] = {
-      case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+    CleanupAliases(propagateFoldables(plan)._1)
+  }
+
+  private def propagateFoldables(plan: LogicalPlan): (LogicalPlan, AttributeMap[Alias]) = {
+    plan match {
+      case p: Project =>
+        val (newChild, foldableMap) = propagateFoldables(p.child)
+        val newProject =
+          replaceFoldable(p.withNewChildren(Seq(newChild)).asInstanceOf[Project], foldableMap)
+        val newFoldableMap = AttributeMap(newProject.projectList.collect {
+          case a: Alias if a.child.foldable => (a.toAttribute, a)
+        })
+        (newProject, newFoldableMap)
+
+      // We can not replace the attributes in `Expand.output`. If there are other non-leaf
+      // operators that have the `output` field, we should put them here too.
+      case e: Expand =>
+        val (newChild, foldableMap) = propagateFoldables(e.child)
+        val expandWithNewChildren = e.withNewChildren(Seq(newChild)).asInstanceOf[Expand]
+        val newExpand = if (foldableMap.isEmpty) {
+          expandWithNewChildren
+        } else {
+          val newProjections = expandWithNewChildren.projections.map(_.map(_.transform {
+            case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+          }))
+          if (newProjections == expandWithNewChildren.projections) {
+            expandWithNewChildren
+          } else {
+            expandWithNewChildren.copy(projections = newProjections)
+          }
+        }
+        (newExpand, foldableMap)
+
+      case u: UnaryNode if canPropagateFoldables(u) =>
+        val (newChild, foldableMap) = propagateFoldables(u.child)
+        val newU = replaceFoldable(u.withNewChildren(Seq(newChild)), foldableMap)
+        (newU, foldableMap)
+
+      // Join derives the output attributes from its child while they are actually not the
+      // same attributes. For example, the output of outer join is not always picked from its
+      // children, but can also be null. We should exclude these miss-derived attributes when
+      // propagating the foldable expressions.
+      // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
+      // of outer join.
+      case j: Join =>
+        val (newChildren, foldableMaps) = j.children.map(propagateFoldables).unzip
+        val foldableMap = AttributeMap(
+          foldableMaps.foldLeft(Iterable.empty[(Attribute, Alias)])(_ ++ _.baseMap.values).toSeq)
+        val newJoin =
+          replaceFoldable(j.withNewChildren(newChildren).asInstanceOf[Join], foldableMap)
+        val missDerivedAttrsSet: AttributeSet = AttributeSet(newJoin.joinType match {
+          case _: InnerLike | LeftExistence(_) => Nil
+          case LeftOuter => newJoin.right.output
+          case RightOuter => newJoin.left.output
+          case FullOuter => newJoin.left.output ++ newJoin.right.output
+        })
+        val newFoldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
+          case (attr, _) => missDerivedAttrsSet.contains(attr)
+        }.toSeq)
+        (newJoin, newFoldableMap)
+
+      // For other plans, they are not safe to apply foldable propagation, and they should not
+      // propagate foldable expressions from children.
+      case o =>
+        val newOther = o.mapChildren(propagateFoldables(_)._1)
+        (newOther, AttributeMap.empty)
     }
+  }
 
+  private def replaceFoldable(plan: LogicalPlan, foldableMap: AttributeMap[Alias]): plan.type = {
     if (foldableMap.isEmpty) {
       plan
     } else {
-      CleanupAliases(plan.transformUp {
-        // We can only propagate foldables for a subset of unary nodes.
-        case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) =>
-          u.transformExpressions(replaceFoldable)
-
-        // Join derives the output attributes from its child while they are actually not the
-        // same attributes. For example, the output of outer join is not always picked from its
-        // children, but can also be null. We should exclude these miss-derived attributes when
-        // propagating the foldable expressions.
-        // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
-        // of outer join.
-        case j @ Join(left, right, joinType, _, _) if foldableMap.nonEmpty =>
-          val newJoin = j.transformExpressions(replaceFoldable)
-          val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match {
-            case _: InnerLike | LeftExistence(_) => Nil
-            case LeftOuter => right.output
-            case RightOuter => left.output
-            case FullOuter => left.output ++ right.output
-          })
-          foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
-            case (attr, _) => missDerivedAttrsSet.contains(attr)
-          }.toSeq)
-          newJoin
-
-        // We can not replace the attributes in `Expand.output`. If there are other non-leaf
-        // operators that have the `output` field, we should put them here too.
-        case expand: Expand if foldableMap.nonEmpty =>
-          expand.copy(projections = expand.projections.map { projection =>
-            projection.map(_.transform(replaceFoldable))
-          })
-
-        // For other plans, they are not safe to apply foldable propagation, and they should not
-        // propagate foldable expressions from children.
-        case other if foldableMap.nonEmpty =>
-          val childrenOutputSet = AttributeSet(other.children.flatMap(_.output))
-          foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
-            case (attr, _) => childrenOutputSet.contains(attr)
-          }.toSeq)
-          other
-      })
+      plan transformExpressions {
+        case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+      }
     }
   }
 
@@ -684,7 +707,7 @@ object FoldablePropagation extends Rule[LogicalPlan] {
    * List of all [[UnaryNode]]s which allow foldable propagation.
    */
   private def canPropagateFoldables(u: UnaryNode): Boolean = u match {
-    case _: Project => true
+    // Handling `Project` is moved to `propagateFoldables`.
     case _: Filter => true
     case _: SubqueryAlias => true
     case _: Aggregate => true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index d95f09a..321f496 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2555,6 +2555,18 @@ class DataFrameSuite extends QueryTest
     val df = Seq(0.0 -> -0.0).toDF("pos", "neg")
     checkAnswer(df.select($"pos" > $"neg"), Row(false))
   }
+
+  test("SPARK-32635: Replace references with foldables coming only from the node's children") {
+    val a = Seq("1").toDF("col1").withColumn("col2", lit("1"))
+    val b = Seq("2").toDF("col1").withColumn("col2", lit("2"))
+    val aub = a.union(b)
+    val c = aub.filter($"col1" === "2").cache()
+    val d = Seq("2").toDF("col4")
+    val r = d.join(aub, $"col2" === $"col4").select("col4")
+    val l = c.select("col2")
+    val df = l.join(r, $"col2" === $"col4", "LeftOuter")
+    checkAnswer(df, Row("2", "2"))
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)

