You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/02/28 07:53:33 UTC
[spark] branch branch-3.4 updated: [SPARK-42548][SQL] Add ReferenceAllColumns to skip rewriting attributes
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new a2d5c9c55d2 [SPARK-42548][SQL] Add ReferenceAllColumns to skip rewriting attributes
a2d5c9c55d2 is described below
commit a2d5c9c55d2adea4a914cc647e1a100b8b8cee5d
Author: ulysses-you <ul...@gmail.com>
AuthorDate: Tue Feb 28 15:52:53 2023 +0800
[SPARK-42548][SQL] Add ReferenceAllColumns to skip rewriting attributes
### What changes were proposed in this pull request?
Add a new trait `ReferenceAllColumns` that overrides `references` with the output of the node's children, so that `transformUpWithNewOutput` can skip such nodes when rewriting attributes.
### Why are the changes needed?
There are two reasons for this new trait:
1. It's dangerous to call `references` on an unresolved plan whose references all come from its children.
2. It's unnecessary to rewrite the attributes of a plan whose references all come from its children.
### Does this PR introduce _any_ user-facing change?
No. This prevents a potential bug.
### How was this patch tested?
Added a new test; CI passes.
Closes #40154 from ulysses-you/references.
Authored-by: ulysses-you <ul...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
(cherry picked from commit db0e8224e1e4c928fa2f7046ae13b6aad2b8cad6)
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../spark/sql/catalyst/plans/QueryPlan.scala | 37 +++++++++++++---------
.../sql/catalyst/plans/ReferenceAllColumns.scala | 34 ++++++++++++++++++++
.../plans/logical/ScriptTransformation.scala | 8 ++---
.../spark/sql/catalyst/plans/logical/object.scala | 8 ++---
.../sql/catalyst/analysis/TypeCoercionSuite.scala | 18 +++++++++++
.../org/apache/spark/sql/execution/objects.scala | 8 ++---
6 files changed, 81 insertions(+), 32 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 90d1bd805cb..ae5e9789dd9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -297,21 +297,28 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
newChild
}
- val attrMappingForCurrentPlan = attrMapping.filter {
- // The `attrMappingForCurrentPlan` is used to replace the attributes of the
- // current `plan`, so the `oldAttr` must be part of `plan.references`.
- case (oldAttr, _) => plan.references.contains(oldAttr)
- }
-
- if (attrMappingForCurrentPlan.nonEmpty) {
- assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
- .exists(_._2.map(_._2.exprId).distinct.length > 1),
- "Found duplicate rewrite attributes")
-
- val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
- // Using attrMapping from the children plans to rewrite their parent node.
- // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
- newPlan = newPlan.rewriteAttrs(attributeRewrites)
+ plan match {
+ case _: ReferenceAllColumns[_] =>
+ // It's dangerous to call `references` on an unresolved `ReferenceAllColumns`, and
+ // it's unnecessary to rewrite its attributes that all of references come from children
+
+ case _ =>
+ val attrMappingForCurrentPlan = attrMapping.filter {
+ // The `attrMappingForCurrentPlan` is used to replace the attributes of the
+ // current `plan`, so the `oldAttr` must be part of `plan.references`.
+ case (oldAttr, _) => plan.references.contains(oldAttr)
+ }
+
+ if (attrMappingForCurrentPlan.nonEmpty) {
+ assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
+ .exists(_._2.map(_._2.exprId).distinct.length > 1),
+ "Found duplicate rewrite attributes")
+
+ val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
+ // Using attrMapping from the children plans to rewrite their parent node.
+ // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
+ newPlan = newPlan.rewriteAttrs(attributeRewrites)
+ }
}
val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/ReferenceAllColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/ReferenceAllColumns.scala
new file mode 100644
index 00000000000..613e2a06f49
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/ReferenceAllColumns.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+
+/**
+ * A trait for plan nodes whose `references` are all of their children's output columns.
+ *
+ * Attribute rewriting can be skipped for a `ReferenceAllColumns` node, since all of its
+ * references come from its children (callers should avoid `references` while it is unresolved).
+ *
+ * Note: [[QueryPlan.transformUpWithNewOutput]] is currently the only place that checks for it.
+ */
+trait ReferenceAllColumns[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[PlanType] =>
+
+  @transient
+  override final lazy val references: AttributeSet = AttributeSet(children.flatMap(_.outputSet))
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
index 5fe5dc37371..e6ebf981bc4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
/**
* Transforms the input by forking and running the specified script.
@@ -30,10 +31,7 @@ case class ScriptTransformation(
script: String,
output: Seq[Attribute],
child: LogicalPlan,
- ioschema: ScriptInputOutputSchema) extends UnaryNode {
- @transient
- override lazy val references: AttributeSet = AttributeSet(child.output)
-
+ ioschema: ScriptInputOutputSchema) extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
override protected def withNewChildInternal(newChild: LogicalPlan): ScriptTransformation =
copy(child = newChild)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index b27c650cfb2..c6a4779374d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
@@ -64,13 +65,8 @@ trait ObjectProducer extends LogicalPlan {
* A trait for logical operators that consumes domain objects as input.
* The output of its child must be a single-field row containing the input object.
*/
-trait ObjectConsumer extends UnaryNode {
+trait ObjectConsumer extends UnaryNode with ReferenceAllColumns[LogicalPlan] { // needs all child columns even if unreferenced
  assert(child.output.length == 1)
-
-  // This operator always need all columns of its child, even it doesn't reference to.
-  @transient
-  override lazy val references: AttributeSet = child.outputSet
-
  def inputObjAttr: Attribute = child.output.head
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index adce553d194..e30cce23136 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.internal.SQLConf
@@ -1740,6 +1741,16 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase {
}
}
}
+
+  test("SPARK-32638: Add ReferenceAllColumns to skip rewriting attributes") { // NOTE(review): this change is SPARK-42548; the SPARK-32638 prefix looks copied - confirm
+    val t1 = LocalRelation(AttributeReference("c", DecimalType(1, 0))()) // same column name as t2,
+    val t2 = LocalRelation(AttributeReference("c", DecimalType(2, 0))()) // but a different decimal precision
+    val unresolved = t1.union(t2).select(UnresolvedStar(None)) // star is left unexpanded (unresolved)
+    val referenceAllColumns = FakeReferenceAllColumns(unresolved) // wrap in a ReferenceAllColumns node
+    val wp1 = widenSetOperationTypes(referenceAllColumns.select(t1.output.head))
+    assert(wp1.isInstanceOf[Project])
+    assert(wp1.expressions.forall(!_.exists(_ == t1.output.head))) // t1's original `c` must be gone
+  }
}
@@ -1798,3 +1809,10 @@ object TypeCoercionSuite {
copy(left = newLeft, right = newRight)
}
}
+
+case class FakeReferenceAllColumns(child: LogicalPlan) // test-only unary node mixing in ReferenceAllColumns
+  extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
+  override def output: Seq[Attribute] = child.output // passes the child's output through unchanged
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(child = newChild)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index bda592ff929..c8d575016fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.python.BatchIterator
@@ -58,13 +59,8 @@ trait ObjectProducerExec extends SparkPlan {
/**
* Physical version of `ObjectConsumer`.
*/
-trait ObjectConsumerExec extends UnaryExecNode {
+trait ObjectConsumerExec extends UnaryExecNode with ReferenceAllColumns[SparkPlan] { // needs all child columns even if unreferenced
  assert(child.output.length == 1)
-
-  // This operator always need all columns of its child, even it doesn't reference to.
-  @transient
-  override lazy val references: AttributeSet = child.outputSet
-
  def inputObjectType: DataType = child.output.head.dataType
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org