Posted to commits@spark.apache.org by we...@apache.org on 2023/02/28 07:53:33 UTC

[spark] branch branch-3.4 updated: [SPARK-42548][SQL] Add ReferenceAllColumns to skip rewriting attributes

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new a2d5c9c55d2 [SPARK-42548][SQL] Add ReferenceAllColumns to skip rewriting attributes
a2d5c9c55d2 is described below

commit a2d5c9c55d2adea4a914cc647e1a100b8b8cee5d
Author: ulysses-you <ul...@gmail.com>
AuthorDate: Tue Feb 28 15:52:53 2023 +0800

    [SPARK-42548][SQL] Add ReferenceAllColumns to skip rewriting attributes
    
    ### What changes were proposed in this pull request?
    
    Add a new trait `ReferenceAllColumns` that overrides `references` using the children's output. This lets us skip such nodes when rewriting attributes in `transformUpWithNewOutput`.
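
    As a minimal sketch (mirroring the `FakeReferenceAllColumns` test helper in the diff below), an operator that consumes every column of its child can mix in the trait instead of overriding `references` by hand; the `PassThrough` node here is purely illustrative:

    ```scala
    import org.apache.spark.sql.catalyst.expressions.Attribute
    import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}

    // Illustrative node only: it forwards its child's output unchanged, so every
    // attribute it references is, by definition, part of the child's output. The
    // mixed-in trait makes `references` exactly the children's outputSet.
    case class PassThrough(child: LogicalPlan)
      extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
      override def output: Seq[Attribute] = child.output
      override protected def withNewChildInternal(newChild: LogicalPlan): PassThrough =
        copy(child = newChild)
    }
    ```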
    
    ### Why are the changes needed?
    
    There are two reasons for this new trait:
    
    1. it's dangerous to call `references` on an unresolved plan whose references all come from its children
    2. it's unnecessary to rewrite the attributes of such a plan: its `references` is defined as exactly its children's output, so once the children have been rewritten bottom-up there is nothing stale left to fix (see the sketch below)
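
    To make the second point concrete, here is a hedged, self-contained sketch reusing the illustrative `PassThrough` node from above: because `references` is derived from the child's output rather than from the node's own (possibly unresolved) expressions, it is both safe to call and already up to date once the children are rewritten:

    ```scala
    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
    import org.apache.spark.sql.types.IntegerType

    object ReferenceAllColumnsDemo extends App {

      val rel = LocalRelation(AttributeReference("a", IntegerType)())
      val node = PassThrough(rel)

      // `references` comes straight from the child's output, never from the
      // node's own expressions, so asking for it cannot trip over anything
      // unresolved, and a rewrite of the child is reflected automatically.
      assert(node.references == rel.outputSet)
    }
    ```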
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it only prevents a potential bug.
    
    ### How was this patch tested?
    
    Added a test and passed CI.
    
    Closes #40154 from ulysses-you/references.
    
    Authored-by: ulysses-you <ul...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit db0e8224e1e4c928fa2f7046ae13b6aad2b8cad6)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/plans/QueryPlan.scala       | 37 +++++++++++++---------
 .../sql/catalyst/plans/ReferenceAllColumns.scala   | 34 ++++++++++++++++++++
 .../plans/logical/ScriptTransformation.scala       |  8 ++---
 .../spark/sql/catalyst/plans/logical/object.scala  |  8 ++---
 .../sql/catalyst/analysis/TypeCoercionSuite.scala  | 18 +++++++++++
 .../org/apache/spark/sql/execution/objects.scala   |  8 ++---
 6 files changed, 81 insertions(+), 32 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 90d1bd805cb..ae5e9789dd9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -297,21 +297,28 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
           newChild
         }
 
-        val attrMappingForCurrentPlan = attrMapping.filter {
-          // The `attrMappingForCurrentPlan` is used to replace the attributes of the
-          // current `plan`, so the `oldAttr` must be part of `plan.references`.
-          case (oldAttr, _) => plan.references.contains(oldAttr)
-        }
-
-        if (attrMappingForCurrentPlan.nonEmpty) {
-          assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
-            .exists(_._2.map(_._2.exprId).distinct.length > 1),
-            "Found duplicate rewrite attributes")
-
-          val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
-          // Using attrMapping from the children plans to rewrite their parent node.
-          // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
-          newPlan = newPlan.rewriteAttrs(attributeRewrites)
+        plan match {
+          case _: ReferenceAllColumns[_] =>
+            // It's dangerous to call `references` on an unresolved `ReferenceAllColumns`, and
+            // it's unnecessary to rewrite its attributes, since they all come from its children.
+
+          case _ =>
+            val attrMappingForCurrentPlan = attrMapping.filter {
+              // The `attrMappingForCurrentPlan` is used to replace the attributes of the
+              // current `plan`, so the `oldAttr` must be part of `plan.references`.
+              case (oldAttr, _) => plan.references.contains(oldAttr)
+            }
+
+            if (attrMappingForCurrentPlan.nonEmpty) {
+              assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
+                .exists(_._2.map(_._2.exprId).distinct.length > 1),
+                "Found duplicate rewrite attributes")
+
+              val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
+              // Using attrMapping from the children plans to rewrite their parent node.
+              // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
+              newPlan = newPlan.rewriteAttrs(attributeRewrites)
+            }
         }
 
         val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/ReferenceAllColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/ReferenceAllColumns.scala
new file mode 100644
index 00000000000..613e2a06f49
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/ReferenceAllColumns.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans
+
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+
+/**
+ * A trait that overrides `references` using children output.
+ *
+ * It's unnecessary to rewrite attributes for `ReferenceAllColumns` since all of its
+ * references come from its children.
+ *
+ * Note: the only place this trait is consulted is [[QueryPlan.transformUpWithNewOutput]].
+ */
+trait ReferenceAllColumns[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[PlanType] =>
+
+  @transient
+  override final lazy val references: AttributeSet = AttributeSet(children.flatMap(_.outputSet))
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
index 5fe5dc37371..e6ebf981bc4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
 
 /**
  * Transforms the input by forking and running the specified script.
@@ -30,10 +31,7 @@ case class ScriptTransformation(
     script: String,
     output: Seq[Attribute],
     child: LogicalPlan,
-    ioschema: ScriptInputOutputSchema) extends UnaryNode {
-  @transient
-  override lazy val references: AttributeSet = AttributeSet(child.output)
-
+    ioschema: ScriptInputOutputSchema) extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
   override protected def withNewChildInternal(newChild: LogicalPlan): ScriptTransformation =
     copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index b27c650cfb2..c6a4779374d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
@@ -64,13 +65,8 @@ trait ObjectProducer extends LogicalPlan {
  * A trait for logical operators that consumes domain objects as input.
  * The output of its child must be a single-field row containing the input object.
  */
-trait ObjectConsumer extends UnaryNode {
+trait ObjectConsumer extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
   assert(child.output.length == 1)
-
-  // This operator always need all columns of its child, even it doesn't reference to.
-  @transient
-  override lazy val references: AttributeSet = child.outputSet
-
   def inputObjAttr: Attribute = child.output.head
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index adce553d194..e30cce23136 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
 import org.apache.spark.sql.internal.SQLConf
@@ -1740,6 +1741,16 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase {
       }
     }
   }
+
+  test("SPARK-32638: Add ReferenceAllColumns to skip rewriting attributes") {
+    val t1 = LocalRelation(AttributeReference("c", DecimalType(1, 0))())
+    val t2 = LocalRelation(AttributeReference("c", DecimalType(2, 0))())
+    val unresolved = t1.union(t2).select(UnresolvedStar(None))
+    val referenceAllColumns = FakeReferenceAllColumns(unresolved)
+    val wp1 = widenSetOperationTypes(referenceAllColumns.select(t1.output.head))
+    assert(wp1.isInstanceOf[Project])
+    assert(wp1.expressions.forall(!_.exists(_ == t1.output.head)))
+  }
 }
 
 
@@ -1798,3 +1809,10 @@ object TypeCoercionSuite {
       copy(left = newLeft, right = newRight)
   }
 }
+
+case class FakeReferenceAllColumns(child: LogicalPlan)
+  extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
+  override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
+    copy(child = newChild)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index bda592ff929..c8d575016fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
 import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.python.BatchIterator
@@ -58,13 +59,8 @@ trait ObjectProducerExec extends SparkPlan {
 /**
  * Physical version of `ObjectConsumer`.
  */
-trait ObjectConsumerExec extends UnaryExecNode {
+trait ObjectConsumerExec extends UnaryExecNode with ReferenceAllColumns[SparkPlan] {
   assert(child.output.length == 1)
-
-  // This operator always need all columns of its child, even it doesn't reference to.
-  @transient
-  override lazy val references: AttributeSet = child.outputSet
-
   def inputObjectType: DataType = child.output.head.dataType
 }
 

