You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ya...@apache.org on 2020/09/17 23:20:36 UTC
[spark] branch master updated: [SPARK-32635][SQL] Fix foldable propagation
This is an automated email from the ASF dual-hosted git repository.
yamamuro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4ced588 [SPARK-32635][SQL] Fix foldable propagation
4ced588 is described below
commit 4ced58862c707aa916f7a55d15c3887c94c9b210
Author: Peter Toth <pe...@gmail.com>
AuthorDate: Fri Sep 18 08:17:23 2020 +0900
[SPARK-32635][SQL] Fix foldable propagation
### What changes were proposed in this pull request?
This PR rewrites `FoldablePropagation` rule to replace attribute references in a node with foldables coming only from the node's children.
Before this PR, in the following example (run with `spark.sql.optimizer.excludedRules=org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation` set):
```scala
val a = Seq("1").toDF("col1").withColumn("col2", lit("1"))
val b = Seq("2").toDF("col1").withColumn("col2", lit("2"))
val aub = a.union(b)
val c = aub.filter($"col1" === "2").cache()
val d = Seq("2").toDF( "col4")
val r = d.join(aub, $"col2" === $"col4").select("col4")
val l = c.select("col2")
val df = l.join(r, $"col2" === $"col4", "LeftOuter")
df.show()
```
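foldable propagation happens incorrectly. The `Project [col2#6]` directly under the outer join is rewritten to `Project [1 AS col2#6]`. This happens because the rule collected foldable aliases from every `Project` in the whole plan, including the unrelated `Project [1 AS col2#6]` in one branch of the union on the join's right side. The plans before and after the rule: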
Before:
```
Join LeftOuter, (col2#6 = col4#34)
:- Project [col2#6]
:  +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas)
:        +- Union
:           :- *(1) Project [value#1 AS col1#4, 1 AS col2#6]
:           :  +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2))
:           :     +- *(1) LocalTableScan [value#1]
:           +- *(2) Project [value#10 AS col1#13, 2 AS col2#15]
:              +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2))
:                 +- *(2) LocalTableScan [value#10]
+- Project [col4#34]
   +- Join Inner, (col2#6 = col4#34)
      :- Project [value#31 AS col4#34]
      :  +- LocalRelation [value#31]
      +- Project [col2#6]
         +- Union false, false
            :- Project [1 AS col2#6]
            :  +- LocalRelation [value#1]
            +- Project [2 AS col2#15]
               +- LocalRelation [value#10]
```
After (only the second line differs):
```
Join LeftOuter, (col2#6 = col4#34)
:- Project [1 AS col2#6]
:  +- InMemoryRelation [col1#4, col2#6], StorageLevel(disk, memory, deserialized, 1 replicas)
:        +- Union
:           :- *(1) Project [value#1 AS col1#4, 1 AS col2#6]
:           :  +- *(1) Filter (isnotnull(value#1) AND (value#1 = 2))
:           :     +- *(1) LocalTableScan [value#1]
:           +- *(2) Project [value#10 AS col1#13, 2 AS col2#15]
:              +- *(2) Filter (isnotnull(value#10) AND (value#10 = 2))
:                 +- *(2) LocalTableScan [value#10]
+- Project [col4#34]
   +- Join Inner, (col2#6 = col4#34)
      :- Project [value#31 AS col4#34]
      :  +- LocalRelation [value#31]
      +- Project [col2#6]
         +- Union false, false
            :- Project [1 AS col2#6]
            :  +- LocalRelation [value#1]
            +- Project [2 AS col2#15]
               +- LocalRelation [value#10]
```
and so the result is wrong:
```
+----+----+
|col2|col4|
+----+----+
| 1|null|
+----+----+
```
After this PR, foldable propagation no longer replaces the attribute incorrectly, and the result is correct:
```
+----+----+
|col2|col4|
+----+----+
| 2| 2|
+----+----+
```
### Why are the changes needed?
To fix a correctness issue.
### Does this PR introduce _any_ user-facing change?
Yes, fixes a correctness issue.
### How was this patch tested?
Existing and new unit tests.
Closes #29771 from peter-toth/SPARK-32635-fix-foldable-propagation.
Authored-by: Peter Toth <pe...@gmail.com>
Signed-off-by: Takeshi Yamamuro <ya...@apache.org>
---
.../sql/catalyst/expressions/AttributeMap.scala | 2 +
.../sql/catalyst/expressions/AttributeMap.scala | 2 +
.../spark/sql/catalyst/optimizer/expressions.scala | 121 ++++++++++++---------
.../org/apache/spark/sql/DataFrameSuite.scala | 12 ++
4 files changed, 88 insertions(+), 49 deletions(-)
diff --git a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 75a8bec..42b92d4 100644
--- a/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala-2.12/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -26,6 +26,8 @@ object AttributeMap {
def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
}
+
+ def empty[A]: AttributeMap[A] = new AttributeMap(Map.empty)
}
class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
diff --git a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index 4caa3d0..e6b53e3 100644
--- a/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala-2.13/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -26,6 +26,8 @@ object AttributeMap {
def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
}
+
+ def empty[A]: AttributeMap[A] = new AttributeMap(Map.empty)
}
class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index b2fc393..c4e4b25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -624,59 +624,82 @@ object NullPropagation extends Rule[LogicalPlan] {
*/
object FoldablePropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
- var foldableMap = AttributeMap(plan.flatMap {
- case Project(projectList, _) => projectList.collect {
- case a: Alias if a.child.foldable => (a.toAttribute, a)
- }
- case _ => Nil
- })
- val replaceFoldable: PartialFunction[Expression, Expression] = {
- case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+ CleanupAliases(propagateFoldables(plan)._1)
+ }
+
+ private def propagateFoldables(plan: LogicalPlan): (LogicalPlan, AttributeMap[Alias]) = {
+ plan match {
+ case p: Project =>
+ val (newChild, foldableMap) = propagateFoldables(p.child)
+ val newProject =
+ replaceFoldable(p.withNewChildren(Seq(newChild)).asInstanceOf[Project], foldableMap)
+ val newFoldableMap = AttributeMap(newProject.projectList.collect {
+ case a: Alias if a.child.foldable => (a.toAttribute, a)
+ })
+ (newProject, newFoldableMap)
+
+ // We can not replace the attributes in `Expand.output`. If there are other non-leaf
+ // operators that have the `output` field, we should put them here too.
+ case e: Expand =>
+ val (newChild, foldableMap) = propagateFoldables(e.child)
+ val expandWithNewChildren = e.withNewChildren(Seq(newChild)).asInstanceOf[Expand]
+ val newExpand = if (foldableMap.isEmpty) {
+ expandWithNewChildren
+ } else {
+ val newProjections = expandWithNewChildren.projections.map(_.map(_.transform {
+ case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+ }))
+ if (newProjections == expandWithNewChildren.projections) {
+ expandWithNewChildren
+ } else {
+ expandWithNewChildren.copy(projections = newProjections)
+ }
+ }
+ (newExpand, foldableMap)
+
+ case u: UnaryNode if canPropagateFoldables(u) =>
+ val (newChild, foldableMap) = propagateFoldables(u.child)
+ val newU = replaceFoldable(u.withNewChildren(Seq(newChild)), foldableMap)
+ (newU, foldableMap)
+
+ // Join derives the output attributes from its child while they are actually not the
+ // same attributes. For example, the output of outer join is not always picked from its
+ // children, but can also be null. We should exclude these miss-derived attributes when
+ // propagating the foldable expressions.
+ // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
+ // of outer join.
+ case j: Join =>
+ val (newChildren, foldableMaps) = j.children.map(propagateFoldables).unzip
+ val foldableMap = AttributeMap(
+ foldableMaps.foldLeft(Iterable.empty[(Attribute, Alias)])(_ ++ _.baseMap.values).toSeq)
+ val newJoin =
+ replaceFoldable(j.withNewChildren(newChildren).asInstanceOf[Join], foldableMap)
+ val missDerivedAttrsSet: AttributeSet = AttributeSet(newJoin.joinType match {
+ case _: InnerLike | LeftExistence(_) => Nil
+ case LeftOuter => newJoin.right.output
+ case RightOuter => newJoin.left.output
+ case FullOuter => newJoin.left.output ++ newJoin.right.output
+ })
+ val newFoldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
+ case (attr, _) => missDerivedAttrsSet.contains(attr)
+ }.toSeq)
+ (newJoin, newFoldableMap)
+
+ // For other plans, they are not safe to apply foldable propagation, and they should not
+ // propagate foldable expressions from children.
+ case o =>
+ val newOther = o.mapChildren(propagateFoldables(_)._1)
+ (newOther, AttributeMap.empty)
}
+ }
+ private def replaceFoldable(plan: LogicalPlan, foldableMap: AttributeMap[Alias]): plan.type = {
if (foldableMap.isEmpty) {
plan
} else {
- CleanupAliases(plan.transformUp {
- // We can only propagate foldables for a subset of unary nodes.
- case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) =>
- u.transformExpressions(replaceFoldable)
-
- // Join derives the output attributes from its child while they are actually not the
- // same attributes. For example, the output of outer join is not always picked from its
- // children, but can also be null. We should exclude these miss-derived attributes when
- // propagating the foldable expressions.
- // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
- // of outer join.
- case j @ Join(left, right, joinType, _, _) if foldableMap.nonEmpty =>
- val newJoin = j.transformExpressions(replaceFoldable)
- val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match {
- case _: InnerLike | LeftExistence(_) => Nil
- case LeftOuter => right.output
- case RightOuter => left.output
- case FullOuter => left.output ++ right.output
- })
- foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
- case (attr, _) => missDerivedAttrsSet.contains(attr)
- }.toSeq)
- newJoin
-
- // We can not replace the attributes in `Expand.output`. If there are other non-leaf
- // operators that have the `output` field, we should put them here too.
- case expand: Expand if foldableMap.nonEmpty =>
- expand.copy(projections = expand.projections.map { projection =>
- projection.map(_.transform(replaceFoldable))
- })
-
- // For other plans, they are not safe to apply foldable propagation, and they should not
- // propagate foldable expressions from children.
- case other if foldableMap.nonEmpty =>
- val childrenOutputSet = AttributeSet(other.children.flatMap(_.output))
- foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot {
- case (attr, _) => childrenOutputSet.contains(attr)
- }.toSeq)
- other
- })
+ plan transformExpressions {
+ case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
+ }
}
}
@@ -684,7 +707,7 @@ object FoldablePropagation extends Rule[LogicalPlan] {
* List of all [[UnaryNode]]s which allow foldable propagation.
*/
private def canPropagateFoldables(u: UnaryNode): Boolean = u match {
- case _: Project => true
+ // Handling `Project` is moved to `propagateFoldables`.
case _: Filter => true
case _: SubqueryAlias => true
case _: Aggregate => true
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index d95f09a..321f496 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2555,6 +2555,18 @@ class DataFrameSuite extends QueryTest
val df = Seq(0.0 -> -0.0).toDF("pos", "neg")
checkAnswer(df.select($"pos" > $"neg"), Row(false))
}
+
+ test("SPARK-32635: Replace references with foldables coming only from the node's children") {
+ val a = Seq("1").toDF("col1").withColumn("col2", lit("1"))
+ val b = Seq("2").toDF("col1").withColumn("col2", lit("2"))
+ val aub = a.union(b)
+ val c = aub.filter($"col1" === "2").cache()
+ val d = Seq("2").toDF("col4")
+ val r = d.join(aub, $"col2" === $"col4").select("col4")
+ val l = c.select("col2")
+ val df = l.join(r, $"col2" === $"col4", "LeftOuter")
+ checkAnswer(df, Row("2", "2"))
+ }
}
case class GroupByKey(a: Int, b: Int)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org