Posted to commits@spark.apache.org by ru...@apache.org on 2023/12/29 01:25:30 UTC
(spark) branch master updated: [SPARK-46484][SQL][CONNECT] Make `resolveOperators*` helper functions keep the plan id
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 826f8d9eb04 [SPARK-46484][SQL][CONNECT] Make `resolveOperators*` helper functions keep the plan id
826f8d9eb04 is described below
commit 826f8d9eb040f2801a57544ea3e3e95c6f87154d
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Fri Dec 29 09:25:12 2023 +0800
[SPARK-46484][SQL][CONNECT] Make `resolveOperators*` helper functions keep the plan id
### What changes were proposed in this pull request?
1. Make the following helper functions keep the plan id during transformation:
- `resolveOperatorsDownWithPruning`
- `resolveOperatorsUpWithNewOutput`
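A minimal, self-contained sketch of the intended helper behavior (a toy `Node` tree with plain string tags, not the actual Catalyst API): whenever a rule returns a new node, the old node's tags are copied onto it, so tags like `PLAN_ID_TAG` survive the rewrite.
```
import scala.collection.mutable

case class Node(name: String, children: Seq[Node] = Nil) {
  val tags: mutable.Map[String, Any] = mutable.Map.empty

  // Same "only copy into an empty tag map" guard as the real TreeNode.copyTagsFrom.
  def copyTagsFrom(other: Node): Unit = if (tags.isEmpty) tags ++= other.tags

  def resolveDown(rule: PartialFunction[Node, Node]): Node = {
    val afterRule = rule.applyOrElse(this, identity[Node])
    if (afterRule ne this) afterRule.copyTagsFrom(this) // keep e.g. PLAN_ID_TAG
    val newNode = afterRule.copy(children = afterRule.children.map(_.resolveDown(rule)))
    newNode.copyTagsFrom(afterRule) // tags also survive the children rewrite
    newNode
  }
}

object Demo extends App {
  val join = Node("Join")
  join.tags("PLAN_ID_TAG") = 42L
  val resolved = join.resolveDown { case Node("Join", cs) => Node("ResolvedJoin", cs) }
  assert(resolved.tags("PLAN_ID_TAG") == 42L) // the plan id survived the rewrite
}
```
The actual change lands in `AnalysisHelper` and `QueryPlan` (see the diff below).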
2. Change the way the plan id is kept in `ResolveNaturalAndUsingJoin`:
before:
```
Project <- tag hiddenOutputTag
- Join <- tag PLAN_ID_TAG
```
after:
```
Project <- tag hiddenOutputTag & PLAN_ID_TAG
- Join
```
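With this shape, nothing relies on tag copying reaching the inner `Join` anymore. A toy sketch of the resulting tag placement (simplified classes, not the real `Project`/`Join` nodes):
```
import scala.collection.mutable

case class Plan(name: String, children: Seq[Plan] = Nil) {
  val tags: mutable.Map[String, Any] = mutable.Map.empty
}

object TagPlacement extends App {
  val join = Plan("Join")
  val project = Plan("Project", Seq(join))
  project.tags("hiddenOutputTag") = Seq("joinKey") // set by commonNaturalJoinProcessing
  project.tags("PLAN_ID_TAG") = 7L                 // now set on the Project, not the Join
  assert(join.tags.isEmpty) // the Join no longer carries the plan id
}
```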
3. To verify this fix, this PR also reverts the previous tag-copying changes in the rules listed below.
### Why are the changes needed?
We had made the following rules keep the plan id:
1. `ResolveNaturalAndUsingJoin` in https://github.com/apache/spark/commit/167bbca49c1c12ccd349d4330862c136b38d4522
- it uses `resolveOperatorsUpWithPruning` and sets the `Project.hiddenOutputTag` tag internally, so the `copyTagsFrom` in `resolveOperatorsUpWithPruning` (which only copies when `tags.isEmpty`) takes no effect (see the sketch at the end of this section)
2. `ExtractWindowExpressions` in https://github.com/apache/spark/commit/185a0a5a23958676e4236eaf9e4d78cdfd2dd2d7
- it uses `resolveOperatorsDownWithPruning`, which doesn't copy tags
3. `WidenSetOperationTypes` in https://github.com/apache/spark/commit/17c206fb71d03aefa75ecb87ca82772980dab954
- it uses `resolveOperatorsUpWithNewOutput -> transformUpWithNewOutput`, which doesn't copy tags
4. `ResolvePivot` in https://github.com/apache/spark/commit/1a89bdc60d55394a1a9d94d4fa69fa5ab8041671
- it uses `resolveOperatorsWithPruning -> resolveOperatorsDownWithPruning`, which doesn't copy tags
5. `CTESubstitution` in https://github.com/apache/spark/commit/79d1cded8555c5a0cc97b76747753785477eab8f
- it uses both `resolveOperatorsDownWithPruning` and `resolveOperatorsUp -> resolveOperatorsUpWithPruning`; the former doesn't copy tags
But missing plan id issues still keep popping up (see https://github.com/apache/spark/pull/44454), so this PR attempts to cover more cases by fixing the helper functions that the rules are built on. This also covers the following rules, which did not copy tags before:
6. `ResolveUnpivot`
- it uses `resolveOperatorsWithPruning -> resolveOperatorsDownWithPruning`, which doesn't copy tags
7. `UnpivotCoercion`
- it uses `resolveOperators -> resolveOperatorsWithPruning -> resolveOperatorsDownWithPruning`, which doesn't copy tags
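To illustrate item 1 above: `copyTagsFrom` only copies into an empty tag map, so a rule that tags its result internally silently drops the incoming plan id. A toy sketch of that failure mode (simplified nodes, not Catalyst classes):
```
import scala.collection.mutable

case class N(name: String) {
  val tags: mutable.Map[String, Any] = mutable.Map.empty
  // Same "only copy into an empty tag map" guard as the real TreeNode.copyTagsFrom.
  def copyTagsFrom(other: N): Unit = if (tags.isEmpty) tags ++= other.tags
}

object LostTag extends App {
  val join = N("Join")
  join.tags("PLAN_ID_TAG") = 7L
  val project = N("Project")
  project.tags("hiddenOutputTag") = Seq("joinKey") // the rule itself sets this tag
  project.copyTagsFrom(join)                       // no-op: the tag map is not empty
  assert(!project.tags.contains("PLAN_ID_TAG"))    // the plan id is silently dropped
}
```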
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Unit test (`test_melt_groupby` added in `python/pyspark/sql/tests/test_dataframe.py`).
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44462 from zhengruifeng/sql_res_op_keep.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
python/pyspark/sql/tests/test_dataframe.py | 18 +++++++++
.../spark/sql/catalyst/analysis/Analyzer.scala | 45 +++++++++-------------
.../sql/catalyst/analysis/CTESubstitution.scala | 9 ++---
.../spark/sql/catalyst/analysis/TypeCoercion.scala | 12 ++----
.../spark/sql/catalyst/plans/QueryPlan.scala | 3 ++
.../catalyst/plans/logical/AnalysisHelper.scala | 6 ++-
6 files changed, 50 insertions(+), 43 deletions(-)
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 692cf77d9af..f1d690751ea 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -963,6 +963,24 @@ class DataFrameTestsMixin:
):
df.unpivot("id", ["int", "str"], "var", "val").collect()
+ def test_melt_groupby(self):
+ df = self.spark.createDataFrame(
+ [(1, 2, 3, 4, 5, 6)],
+ ["f1", "f2", "label", "pred", "model_version", "ts"],
+ )
+ self.assertEqual(
+ df.melt(
+ "model_version",
+ ["label", "f2"],
+ "f1",
+ "f2",
+ )
+ .groupby("f1")
+ .count()
+ .count(),
+ 2,
+ )
+
def test_observe(self):
# SPARK-36263: tests the DataFrame.observe(Observation, *Column) method
from pyspark.sql import Observation
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 94f6d334626..a57fd7a31d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -765,7 +765,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case p: Pivot if !p.childrenResolved || !p.aggregates.forall(_.resolved)
|| (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved))
|| !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p
- case p @ Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
+ case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
if (!RowOrdering.isOrderable(pivotColumn.dataType)) {
throw QueryCompilationErrors.unorderablePivotColError(pivotColumn)
}
@@ -829,9 +829,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))()
}
}
- val newProject = Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
- newProject.copyTagsFrom(p)
- newProject
+ Project(groupByExprsAttr ++ pivotOutputs, secondAgg)
} else {
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
def ifExpr(e: Expression) = {
@@ -865,9 +863,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
Alias(filteredAggregate, outputName(value, aggregate))()
}
}
- val newAggregate = Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
- newAggregate.copyTagsFrom(p)
- newAggregate
+ Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
}
}
@@ -3264,9 +3260,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// Finally, generate output columns according to the original projectList.
val finalProjectList = aggregateExprs.map(_.toAttribute)
- val newProject = Project(finalProjectList, withWindow)
- newProject.copyTagsFrom(f)
- newProject
+ Project(finalProjectList, withWindow)
case p: LogicalPlan if !p.childrenResolved => p
@@ -3284,9 +3278,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// Finally, generate output columns according to the original projectList.
val finalProjectList = aggregateExprs.map(_.toAttribute)
- val newProject = Project(finalProjectList, withWindow)
- newProject.copyTagsFrom(a)
- newProject
+ Project(finalProjectList, withWindow)
// We only extract Window Expressions after all expressions of the Project
// have been resolved, and lateral column aliases are properly handled first.
@@ -3303,9 +3295,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// Finally, generate output columns according to the original projectList.
val finalProjectList = projectList.map(_.toAttribute)
- val newProject = Project(finalProjectList, withWindow)
- newProject.copyTagsFrom(p)
- newProject
+ Project(finalProjectList, withWindow)
}
}
@@ -3461,14 +3451,20 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
_.containsPattern(NATURAL_LIKE_JOIN), ruleId) {
case j @ Join(left, right, UsingJoin(joinType, usingCols), _, hint)
if left.resolved && right.resolved && j.duplicateResolved =>
- commonNaturalJoinProcessing(left, right, joinType, usingCols, None, hint,
- j.getTagValue(LogicalPlan.PLAN_ID_TAG))
+ val project = commonNaturalJoinProcessing(
+ left, right, joinType, usingCols, None, hint)
+ j.getTagValue(LogicalPlan.PLAN_ID_TAG)
+ .foreach(project.setTagValue(LogicalPlan.PLAN_ID_TAG, _))
+ project
case j @ Join(left, right, NaturalJoin(joinType), condition, hint)
if j.resolvedExceptNatural =>
// find common column names from both sides
val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
- commonNaturalJoinProcessing(left, right, joinType, joinNames, condition, hint,
- j.getTagValue(LogicalPlan.PLAN_ID_TAG))
+ val project = commonNaturalJoinProcessing(
+ left, right, joinType, joinNames, condition, hint)
+ j.getTagValue(LogicalPlan.PLAN_ID_TAG)
+ .foreach(project.setTagValue(LogicalPlan.PLAN_ID_TAG, _))
+ project
}
}
@@ -3516,8 +3512,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
joinType: JoinType,
joinNames: Seq[String],
condition: Option[Expression],
- hint: JoinHint,
- planId: Option[Long] = None): LogicalPlan = {
+ hint: JoinHint): LogicalPlan = {
import org.apache.spark.sql.catalyst.util._
val leftKeys = joinNames.map { keyName =>
@@ -3570,13 +3565,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
throw QueryExecutionErrors.unsupportedNaturalJoinTypeError(joinType)
}
- val newJoin = Join(left, right, joinType, newCondition, hint)
- // retain the plan id used in Spark Connect
- planId.foreach(newJoin.setTagValue(LogicalPlan.PLAN_ID_TAG, _))
-
// use Project to hide duplicated common keys
// propagate hidden columns from nested USING/NATURAL JOINs
- val project = Project(projectList, newJoin)
+ val project = Project(projectList, Join(left, right, joinType, newCondition, hint))
project.setTagValue(
Project.hiddenOutputTag,
hiddenList.map(_.markAsQualifiedAccessOnly()) ++
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
index 173c9d44a2b..2982d8477fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala
@@ -149,12 +149,10 @@ object CTESubstitution extends Rule[LogicalPlan] {
plan: LogicalPlan,
cteDefs: ArrayBuffer[CTERelationDef]): LogicalPlan = {
plan.resolveOperatorsUp {
- case cte @ UnresolvedWith(child, relations) =>
+ case UnresolvedWith(child, relations) =>
val resolvedCTERelations =
resolveCTERelations(relations, isLegacy = true, forceInline = false, Seq.empty, cteDefs)
- val substituted = substituteCTE(child, alwaysInline = true, resolvedCTERelations)
- substituted.copyTagsFrom(cte)
- substituted
+ substituteCTE(child, alwaysInline = true, resolvedCTERelations)
}
}
@@ -204,7 +202,7 @@ object CTESubstitution extends Rule[LogicalPlan] {
var firstSubstituted: Option[LogicalPlan] = None
val newPlan = plan.resolveOperatorsDownWithPruning(
_.containsAnyPattern(UNRESOLVED_WITH, PLAN_EXPRESSION)) {
- case cte @ UnresolvedWith(child: LogicalPlan, relations) =>
+ case UnresolvedWith(child: LogicalPlan, relations) =>
val resolvedCTERelations =
resolveCTERelations(relations, isLegacy = false, forceInline, outerCTEDefs, cteDefs) ++
outerCTEDefs
@@ -215,7 +213,6 @@ object CTESubstitution extends Rule[LogicalPlan] {
if (firstSubstituted.isEmpty) {
firstSubstituted = Some(substituted)
}
- substituted.copyTagsFrom(cte)
substituted
case other =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index c5e98683c74..56e8843fda5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -266,10 +266,8 @@ abstract class TypeCoercionBase {
s -> Nil
} else {
assert(newChildren.length == 2)
- val newExcept = Except(newChildren.head, newChildren.last, isAll)
- newExcept.copyTagsFrom(s)
val attrMapping = left.output.zip(newChildren.head.output)
- newExcept -> attrMapping
+ Except(newChildren.head, newChildren.last, isAll) -> attrMapping
}
case s @ Intersect(left, right, isAll) if s.childrenResolved &&
@@ -279,10 +277,8 @@ abstract class TypeCoercionBase {
s -> Nil
} else {
assert(newChildren.length == 2)
- val newIntersect = Intersect(newChildren.head, newChildren.last, isAll)
- newIntersect.copyTagsFrom(s)
val attrMapping = left.output.zip(newChildren.head.output)
- newIntersect -> attrMapping
+ Intersect(newChildren.head, newChildren.last, isAll) -> attrMapping
}
case s: Union if s.childrenResolved && !s.byName &&
@@ -292,9 +288,7 @@ abstract class TypeCoercionBase {
s -> Nil
} else {
val attrMapping = s.children.head.output.zip(newChildren.head.output)
- val newUnion = s.copy(children = newChildren)
- newUnion.copyTagsFrom(s)
- newUnion -> attrMapping
+ s.copy(children = newChildren) -> attrMapping
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index ef7cd7401f2..2a62ea1feb0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -361,6 +361,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
} else {
transferAttrMapping ++ newOtherAttrMapping
}
+ if (!(plan eq planAfterRule)) {
+ planAfterRule.copyTagsFrom(plan)
+ }
planAfterRule -> resultAttrMapping.toSeq
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
index 56f6b116759..45a20cbe3aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
@@ -177,10 +177,14 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
self.markRuleAsIneffective(ruleId)
self
} else {
+ rewritten_plan.copyTagsFrom(self)
rewritten_plan
}
} else {
- afterRule.mapChildren(_.resolveOperatorsDownWithPruning(cond, ruleId)(rule))
+ val newPlan = afterRule
+ .mapChildren(_.resolveOperatorsDownWithPruning(cond, ruleId)(rule))
+ newPlan.copyTagsFrom(self)
+ newPlan
}
}
} else {