Posted to commits@spark.apache.org by ru...@apache.org on 2023/12/29 01:25:30 UTC

(spark) branch master updated: [SPARK-46484][SQL][CONNECT] Make `resolveOperators*` helper functions keep the plan id

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 826f8d9eb04 [SPARK-46484][SQL][CONNECT] Make `resolveOperators*` helper functions keep the plan id
826f8d9eb04 is described below

commit 826f8d9eb040f2801a57544ea3e3e95c6f87154d
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Fri Dec 29 09:25:12 2023 +0800

    [SPARK-46484][SQL][CONNECT] Make `resolveOperators*` helper functions keep the plan id
    
    ### What changes were proposed in this pull request?
    1, make the following helper functions keep the plan id during transformation (a minimal sketch of the idea is shown after item 3 below):
    
    - `resolveOperatorsDownWithPruning`
    - `resolveOperatorsUpWithNewOutput`
    
    2, change how the plan id is kept in `ResolveNaturalAndUsingJoin`:
    before:
    ```
    Project <- tag hiddenOutputTag
      - Join <- tag PLAN_ID_TAG
    ```
    
    after:
    ```
    Project <- tag hiddenOutputTag & PLAN_ID_TAG
      - Join
    ```
    
    3, to verify this fix, this PR also reverts the previous tag-copying changes in the individual rules
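    
    As a rough illustration of what "keep the plan id" means, here is a minimal sketch; `resolveDownKeepingTags` is a made-up name, and the real change lives in `AnalysisHelper` / `QueryPlan`, not in this exact form:
    ```
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    
    // Sketch only: after applying the rule and recursing into the children,
    // copy the node tags (including LogicalPlan.PLAN_ID_TAG used by Spark
    // Connect) from the original node onto the rewritten node. copyTagsFrom
    // is a no-op when the new node already carries tags.
    def resolveDownKeepingTags(
        plan: LogicalPlan)(rule: PartialFunction[LogicalPlan, LogicalPlan]): LogicalPlan = {
      val afterRule = rule.applyOrElse(plan, identity[LogicalPlan])
      val newPlan = afterRule.mapChildren(resolveDownKeepingTags(_)(rule))
      if (!(newPlan eq plan)) {
        newPlan.copyTagsFrom(plan)
      }
      newPlan
    }
    ```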
    
    ### Why are the changes needed?
    we had already made the following rules keep the plan id:
    1, `ResolveNaturalAndUsingJoin` in https://github.com/apache/spark/commit/167bbca49c1c12ccd349d4330862c136b38d4522
    - using `resolveOperatorsUpWithPruning`; the rule sets the tag `Project.hiddenOutputTag` internally, so the `copyTagsFrom` in `resolveOperatorsUpWithPruning` (which only copies if `tags.isEmpty`) takes no effect (see the standalone sketch after this list)
    
    2, `ExtractWindowExpressions` in https://github.com/apache/spark/commit/185a0a5a23958676e4236eaf9e4d78cdfd2dd2d7
    - using `resolveOperatorsDownWithPruning`, which doesn't copy tags
    
    3, `WidenSetOperationTypes` in https://github.com/apache/spark/commit/17c206fb71d03aefa75ecb87ca82772980dab954
    - using `resolveOperatorsUpWithNewOutput -> transformUpWithNewOutput`, which doesn't copy tags
    
    4, `ResolvePivot` in https://github.com/apache/spark/commit/1a89bdc60d55394a1a9d94d4fa69fa5ab8041671
    - using `resolveOperatorsWithPruning -> resolveOperatorsDownWithPruning`, which doesn't copy tags
    
    5, `CTESubstitution` in https://github.com/apache/spark/commit/79d1cded8555c5a0cc97b76747753785477eab8f
    - using both `resolveOperatorsDownWithPruning` and `resolveOperatorsUp -> resolveOperatorsUpWithPruning`; the former doesn't copy tags
    
    But the plan-id-missing issue kept popping up (see https://github.com/apache/spark/pull/44454), so this PR attempts to cover more cases by fixing the helper functions that are used to build rules like the two below:
    
    6, `ResolveUnpivot`
    - using `resolveOperatorsWithPruning -> resolveOperatorsDownWithPruning`, which doesn't copy tags
    
    7, `UnpivotCoercion`
    - using `resolveOperators -> resolveOperatorsWithPruning -> resolveOperatorsDownWithPruning`, which doesn't copy tags
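    
    For point 1 above, a minimal standalone sketch of the `tags.isEmpty` behaviour (using a made-up `FakeNode` class; the real logic is `TreeNode.copyTagsFrom`) shows why tagging the new `Project` with `hiddenOutputTag` first made the later plan id copy a no-op:
    ```
    import scala.collection.mutable
    
    // FakeNode is a stand-in for this demo only; like TreeNode.copyTagsFrom,
    // it skips the copy when the destination already carries any tag.
    class FakeNode {
      val tags = mutable.Map.empty[String, Any]
      def copyTagsFrom(other: FakeNode): Unit = {
        if (tags.isEmpty) {
          tags ++= other.tags
        }
      }
    }
    
    object TagCopyDemo extends App {
      val join = new FakeNode
      join.tags("PLAN_ID_TAG") = 42L                 // plan id set by Spark Connect
      val project = new FakeNode
      project.tags("hiddenOutputTag") = Seq("key")   // the rule tags the Project first
      project.copyTagsFrom(join)
      println(project.tags.contains("PLAN_ID_TAG"))  // false: the plan id was lost
    }
    ```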
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    unit test (new `test_melt_groupby` in `python/pyspark/sql/tests/test_dataframe.py`)
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #44462 from zhengruifeng/sql_res_op_keep.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/sql/tests/test_dataframe.py         | 18 +++++++++
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 45 +++++++++-------------
 .../sql/catalyst/analysis/CTESubstitution.scala    |  9 ++---
 .../spark/sql/catalyst/analysis/TypeCoercion.scala | 12 ++----
 .../spark/sql/catalyst/plans/QueryPlan.scala       |  3 ++
 .../catalyst/plans/logical/AnalysisHelper.scala    |  6 ++-
 6 files changed, 50 insertions(+), 43 deletions(-)

diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 692cf77d9af..f1d690751ea 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -963,6 +963,24 @@ class DataFrameTestsMixin:
             ):
                 df.unpivot("id", ["int", "str"], "var", "val").collect()
 
+    def test_melt_groupby(self):
+        df = self.spark.createDataFrame(
+            [(1, 2, 3, 4, 5, 6)],
+            ["f1", "f2", "label", "pred", "model_version", "ts"],
+        )
+        self.assertEqual(
+            df.melt(
+                "model_version",
+                ["label", "f2"],
+                "f1",
+                "f2",
+            )
+            .groupby("f1")
+            .count()
+            .count(),
+            2,
+        )
+
     def test_observe(self):
         # SPARK-36263: tests the DataFrame.observe(Observation, *Column) method
         from pyspark.sql import Observation
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 94f6d334626..a57fd7a31d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -765,7 +765,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved)
         || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved))
         || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p
-      case p @ Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
+      case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
         if (!RowOrdering.isOrderable(pivotColumn.dataType)) {
           throw QueryCompilationErrors.unorderablePivotColError(pivotColumn)
         }
@@ -829,9 +829,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
               Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))()
             }
           }
-          val newProject = Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
-          newProject.copyTagsFrom(p)
-          newProject
+          Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
         } else {
           val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
             def ifExpr(e: Expression) = {
@@ -865,9 +863,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
               Alias(filteredAggregate, outputName(value, aggregate))()
             }
           }
-          val newAggregate = Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
-          newAggregate.copyTagsFrom(p)
-          newAggregate
+          Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
         }
     }
 
@@ -3264,9 +3260,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
 
         // Finally, generate output columns according to the original projectList.
         val finalProjectList = aggregateExprs.map(_.toAttribute)
-        val newProject = Project(finalProjectList, withWindow)
-        newProject.copyTagsFrom(f)
-        newProject
+        Project(finalProjectList, withWindow)
 
       case p: LogicalPlan if !p.childrenResolved => p
 
@@ -3284,9 +3278,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
 
         // Finally, generate output columns according to the original projectList.
         val finalProjectList = aggregateExprs.map(_.toAttribute)
-        val newProject = Project(finalProjectList, withWindow)
-        newProject.copyTagsFrom(a)
-        newProject
+        Project(finalProjectList, withWindow)
 
       // We only extract Window Expressions after all expressions of the Project
       // have been resolved, and lateral column aliases are properly handled first.
@@ -3303,9 +3295,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
 
         // Finally, generate output columns according to the original projectList.
         val finalProjectList = projectList.map(_.toAttribute)
-        val newProject = Project(finalProjectList, withWindow)
-        newProject.copyTagsFrom(p)
-        newProject
+        Project(finalProjectList, withWindow)
     }
   }
 
@@ -3461,14 +3451,20 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       _.containsPattern(NATURAL_LIKE_JOIN), ruleId) {
       case j @ Join(left, right, UsingJoin(joinType, usingCols), _, hint)
           if left.resolved && right.resolved && j.duplicateResolved =>
-        commonNaturalJoinProcessing(left, right, joinType, usingCols, None, hint,
-          j.getTagValue(LogicalPlan.PLAN_ID_TAG))
+        val project = commonNaturalJoinProcessing(
+          left, right, joinType, usingCols, None, hint)
+        j.getTagValue(LogicalPlan.PLAN_ID_TAG)
+          .foreach(project.setTagValue(LogicalPlan.PLAN_ID_TAG, _))
+        project
       case j @ Join(left, right, NaturalJoin(joinType), condition, hint)
           if j.resolvedExceptNatural =>
         // find common column names from both sides
         val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
-        commonNaturalJoinProcessing(left, right, joinType, joinNames, condition, hint,
-          j.getTagValue(LogicalPlan.PLAN_ID_TAG))
+        val project = commonNaturalJoinProcessing(
+          left, right, joinType, joinNames, condition, hint)
+        j.getTagValue(LogicalPlan.PLAN_ID_TAG)
+          .foreach(project.setTagValue(LogicalPlan.PLAN_ID_TAG, _))
+        project
     }
   }
 
@@ -3516,8 +3512,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       joinType: JoinType,
       joinNames: Seq[String],
       condition: Option[Expression],
-      hint: JoinHint,
-      planId: Option[Long] = None): LogicalPlan = {
+      hint: JoinHint): LogicalPlan = {
     import org.apache.spark.sql.catalyst.util._
 
     val leftKeys = joinNames.map { keyName =>
@@ -3570,13 +3565,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         throw QueryExecutionErrors.unsupportedNaturalJoinTypeError(joinType)
     }
 
-    val newJoin = Join(left, right, joinType, newCondition, hint)
-    // retain the plan id used in Spark Connect
-    planId.foreach(newJoin.setTagValue(LogicalPlan.PLAN_ID_TAG, _))
-
     // use Project to hide duplicated common keys
     // propagate hidden columns from nested USING/NATURAL JOINs
-    val project = Project(projectList, newJoin)
+    val project = Project(projectList, Join(left, right, joinType, newCondition, hint))
     project.setTagValue(
       Project.hiddenOutputTag,
       hiddenList.map(_.markAsQualifiedAccessOnly()) ++
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
index 173c9d44a2b..2982d8477fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
@@ -149,12 +149,10 @@ object CTESubstitution extends Rule[LogicalPlan] {
       plan: LogicalPlan,
       cteDefs: ArrayBuffer[CTERelationDef]): LogicalPlan = {
     plan.resolveOperatorsUp {
-      case cte @ UnresolvedWith(child, relations) =>
+      case UnresolvedWith(child, relations) =>
         val resolvedCTERelations =
           resolveCTERelations(relations, isLegacy = true, forceInline = false, Seq.empty, cteDefs)
-        val substituted = substituteCTE(child, alwaysInline = true, resolvedCTERelations)
-        substituted.copyTagsFrom(cte)
-        substituted
+        substituteCTE(child, alwaysInline = true, resolvedCTERelations)
     }
   }
 
@@ -204,7 +202,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
     var firstSubstituted: Option[LogicalPlan] = None
     val newPlan = plan.resolveOperatorsDownWithPruning(
         _.containsAnyPattern(UNRESOLVED_WITH, PLAN_EXPRESSION)) {
-      case cte @ UnresolvedWith(child: LogicalPlan, relations) =>
+      case UnresolvedWith(child: LogicalPlan, relations) =>
         val resolvedCTERelations =
           resolveCTERelations(relations, isLegacy = false, forceInline, outerCTEDefs, cteDefs) ++
             outerCTEDefs
@@ -215,7 +213,6 @@ object CTESubstitution extends Rule[LogicalPlan] {
         if (firstSubstituted.isEmpty) {
           firstSubstituted = Some(substituted)
         }
-        substituted.copyTagsFrom(cte)
         substituted
 
       case other =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index c5e98683c74..56e8843fda5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -266,10 +266,8 @@ abstract class TypeCoercionBase {
             s -> Nil
           } else {
             assert(newChildren.length == 2)
-            val newExcept = Except(newChildren.head, newChildren.last, isAll)
-            newExcept.copyTagsFrom(s)
             val attrMapping = left.output.zip(newChildren.head.output)
-            newExcept -> attrMapping
+            Except(newChildren.head, newChildren.last, isAll) -> attrMapping
           }
 
         case s @ Intersect(left, right, isAll) if s.childrenResolved &&
@@ -279,10 +277,8 @@ abstract class TypeCoercionBase {
             s -> Nil
           } else {
             assert(newChildren.length == 2)
-            val newIntersect = Intersect(newChildren.head, newChildren.last, isAll)
-            newIntersect.copyTagsFrom(s)
             val attrMapping = left.output.zip(newChildren.head.output)
-            newIntersect -> attrMapping
+            Intersect(newChildren.head, newChildren.last, isAll) -> attrMapping
           }
 
         case s: Union if s.childrenResolved && !s.byName &&
@@ -292,9 +288,7 @@ abstract class TypeCoercionBase {
             s -> Nil
           } else {
             val attrMapping = s.children.head.output.zip(newChildren.head.output)
-            val newUnion = s.copy(children = newChildren)
-            newUnion.copyTagsFrom(s)
-            newUnion -> attrMapping
+            s.copy(children = newChildren) -> attrMapping
           }
       }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index ef7cd7401f2..2a62ea1feb0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -361,6 +361,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
         } else {
           transferAttrMapping ++ newOtherAttrMapping
         }
+        if (!(plan eq planAfterRule)) {
+          planAfterRule.copyTagsFrom(plan)
+        }
         planAfterRule -> resultAttrMapping.toSeq
       }
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
index 56f6b116759..45a20cbe3aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
@@ -177,10 +177,14 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
             self.markRuleAsIneffective(ruleId)
             self
           } else {
+            rewritten_plan.copyTagsFrom(self)
             rewritten_plan
           }
         } else {
-          afterRule.mapChildren(_.resolveOperatorsDownWithPruning(cond, ruleId)(rule))
+          val newPlan = afterRule
+            .mapChildren(_.resolveOperatorsDownWithPruning(cond, ruleId)(rule))
+          newPlan.copyTagsFrom(self)
+          newPlan
         }
       }
     } else {

