Posted to commits@spark.apache.org by ru...@apache.org on 2023/02/13 08:22:42 UTC
[spark] branch master updated: [SPARK-41812][SPARK-41823][CONNECT][SQL][PYTHON] Resolve ambiguous columns issue in `Join`
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 167bbca49c1 [SPARK-41812][SPARK-41823][CONNECT][SQL][PYTHON] Resolve ambiguous columns issue in `Join`
167bbca49c1 is described below
commit 167bbca49c1c12ccd349d4330862c136b38d4522
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Mon Feb 13 16:22:06 2023 +0800
[SPARK-41812][SPARK-41823][CONNECT][SQL][PYTHON] Resolve ambiguous columns issue in `Join`
### What changes were proposed in this pull request?
In the Python client:
- generate a `plan_id` for each proto plan (it is up to the client to guarantee uniqueness; a minimal sketch of the mechanism follows below);
- attach the `plan_id` to columns created by `DataFrame[col_name]` or `DataFrame.col_name`;
- note that `F.col(col_name)` does not carry a `plan_id`;
In the Connect planner:
- attach the `plan_id` to `UnresolvedAttribute`s and `LogicalPlan`s via a `TreeNodeTag`
In the Analyzer:
- for an `UnresolvedAttribute` with a `plan_id`, search the query plan for the matching node and, if found, resolve the attribute against that node
**Out of scope:**
- resolving self-joins
- adding a `DetectAmbiguousSelfJoin`-like rule for detection
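As a rough illustration, here is a minimal sketch of the client-side id mechanism, condensed from the `plan.py` and `expressions.py` changes in this diff (the classes are reduced to just the id-related parts):
```python
from threading import Lock
from typing import Optional


class LogicalPlan:
    """Every client-side plan node draws a unique id from a shared counter."""

    _lock: Lock = Lock()
    _nextPlanId: int = 0

    def __init__(self) -> None:
        # Serialize counter access so concurrent plan construction still
        # yields per-client unique ids.
        with LogicalPlan._lock:
            self._plan_id = LogicalPlan._nextPlanId
            LogicalPlan._nextPlanId += 1


class ColumnReference:
    """`df[name]` / `df.name` passes the owning plan's id; `F.col(name)` passes None."""

    def __init__(self, unparsed_identifier: str, plan_id: Optional[int] = None) -> None:
        self._unparsed_identifier = unparsed_identifier
        # Only written into the proto message when not None, so plain
        # `F.col` references carry no plan binding.
        self._plan_id = plan_id
```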
### Why are the changes needed?
Fixes a bug. Before this PR:
```
df1.join(df2, df1["value"] == df2["value"]) <- fails: cannot resolve `value`
df1.join(df2, df1["value"] == df2["value"]).select(df1.value) <- fails: cannot resolve `value`
df1.select(df2.value) <- should fail, but runs as `df1.select(df1.value)` and returns incorrect results
```
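After this change, the first two joins resolve correctly and the last case fails analysis, matching the tests added in `test_connect_basic.py` below. A usage sketch (assuming `spark` is a Spark Connect session, e.g. created via `SparkSession.builder.remote(...)`):
```python
from pyspark.sql import Row

df1 = spark.createDataFrame([Row(id=1, value="foo"), Row(id=2, value=None)])
df2 = spark.createDataFrame([Row(value="bar"), Row(value=None), Row(value="foo")])

# The two inputs carry distinct plan_ids, so the ambiguous `value` columns resolve.
df1.join(df2, df1["value"] == df2["value"]).select(df1.value).show()

# Referencing a column of an unrelated DataFrame now fails the analysis.
df1.select(df2.value)  # raises AnalysisException
```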
### Does this PR introduce _any_ user-facing change?
Yes. Previously failing join patterns now resolve correctly, and illegal cross-DataFrame references (e.g. `df1.select(df2.value)`) now fail with an `AnalysisException` instead of silently returning wrong results.
### How was this patch tested?
Added new tests and enabled previously skipped ones.
Closes #39925 from zhengruifeng/connect_plan_id.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
.../main/protobuf/spark/connect/expressions.proto | 3 +
.../main/protobuf/spark/connect/relations.proto | 3 +
.../sql/connect/planner/SparkConnectPlanner.scala | 22 ++-
python/pyspark/sql/column.py | 5 +-
python/pyspark/sql/connect/dataframe.py | 15 +-
python/pyspark/sql/connect/expressions.py | 7 +-
python/pyspark/sql/connect/functions.py | 19 +-
python/pyspark/sql/connect/plan.py | 213 +++++++++++----------
.../pyspark/sql/connect/proto/expressions_pb2.py | 54 +++---
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 20 +-
python/pyspark/sql/connect/proto/relations_pb2.py | 200 +++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 15 +-
.../sql/tests/connect/test_connect_basic.py | 62 ++++++
.../pyspark/sql/tests/connect/test_connect_plan.py | 20 +-
.../spark/sql/catalyst/analysis/Analyzer.scala | 16 +-
.../catalyst/analysis/ColumnResolutionHelper.scala | 58 +++++-
.../sql/catalyst/plans/logical/LogicalPlan.scala | 14 +-
17 files changed, 481 insertions(+), 265 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 8682e1ee27b..1929d9cdca3 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -197,6 +197,9 @@ message Expression {
// (Required) An identifier that will be parsed by Catalyst parser. This should follow the
// Spark SQL identifier syntax.
string unparsed_identifier = 1;
+
+ // (Optional) The id of the corresponding connect plan.
+ optional int64 plan_id = 2;
}
// An unresolved function is not explicitly bound to one explicit function, but the function
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index ea1216957d8..29fffd65c75 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -93,6 +93,9 @@ message Unknown {}
message RelationCommon {
// (Required) Shared relation metadata.
string source_info = 1;
+
+ // (Optional) A per-client globally unique id for a given connect plan.
+ optional int64 plan_id = 2;
}
// Relation that uses a SQL query to generate the output.
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 740d6b85964..53d494cdcb7 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -66,7 +66,7 @@ class SparkConnectPlanner(val session: SparkSession) {
// The root of the query plan is a relation and we apply the transformations to it.
def transformRelation(rel: proto.Relation): LogicalPlan = {
- rel.getRelTypeCase match {
+ val plan = rel.getRelTypeCase match {
// DataFrame API
case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
@@ -124,6 +124,11 @@ class SparkConnectPlanner(val session: SparkSession) {
transformRelationPlugin(rel.getExtension)
case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
}
+
+ if (rel.hasCommon && rel.getCommon.hasPlanId) {
+ plan.setTagValue(LogicalPlan.PLAN_ID_TAG, rel.getCommon.getPlanId)
+ }
+ plan
}
private def transformRelationPlugin(extension: ProtoAny): LogicalPlan = {
@@ -702,10 +707,6 @@ class SparkConnectPlanner(val session: SparkSession) {
logical.Project(projectList = projection, child = baseRel)
}
- private def transformUnresolvedExpression(exp: proto.Expression): UnresolvedAttribute = {
- UnresolvedAttribute.quotedString(exp.getUnresolvedAttribute.getUnparsedIdentifier)
- }
-
/**
* Transforms an input protobuf expression into the Catalyst expression. This is usually not
* called directly. Typically the planner will traverse the expressions automatically, only
@@ -720,7 +721,7 @@ class SparkConnectPlanner(val session: SparkSession) {
exp.getExprTypeCase match {
case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral)
case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
- transformUnresolvedExpression(exp)
+ transformUnresolvedAttribute(exp.getUnresolvedAttribute)
case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION =>
transformUnregisteredFunction(exp.getUnresolvedFunction)
.getOrElse(transformUnresolvedFunction(exp.getUnresolvedFunction))
@@ -758,6 +759,15 @@ class SparkConnectPlanner(val session: SparkSession) {
case expr => UnresolvedAlias(expr)
}
+ private def transformUnresolvedAttribute(
+ attr: proto.Expression.UnresolvedAttribute): UnresolvedAttribute = {
+ val expr = UnresolvedAttribute.quotedString(attr.getUnparsedIdentifier)
+ if (attr.hasPlanId) {
+ expr.setTagValue(LogicalPlan.PLAN_ID_TAG, attr.getPlanId)
+ }
+ expr
+ }
+
private def transformExpressionPlugin(extension: ProtoAny): Expression = {
SparkConnectPluginRegistry.expressionRegistry
// Lazily traverse the collection.
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index a790f191110..0b5f94cfaaa 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -279,7 +279,6 @@ class Column:
__ge__ = _bin_op("geq")
__gt__ = _bin_op("gt")
- # TODO(SPARK-41812): DataFrame.join: ambiguous column
_eqNullSafe_doc = """
Equality test that is safe for null values.
@@ -315,9 +314,9 @@ class Column:
... Row(value = 'bar'),
... Row(value = None)
... ])
- >>> df1.join(df2, df1["value"] == df2["value"]).count() # doctest: +SKIP
+ >>> df1.join(df2, df1["value"] == df2["value"]).count()
0
- >>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count() # doctest: +SKIP
+ >>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count()
1
>>> df2 = spark.createDataFrame([
... Row(id=1, value=float('NaN')),
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 95e39f93dc0..667295e8667 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -50,12 +50,14 @@ from pyspark.sql.dataframe import (
)
from pyspark.errors import PySparkTypeError
+from pyspark.errors.exceptions.connect import SparkConnectException
import pyspark.sql.connect.plan as plan
from pyspark.sql.connect.group import GroupedData
from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import UnresolvedRegex
from pyspark.sql.connect.functions import (
+ _to_col_with_plan_id,
_to_col,
_invoke_function,
col,
@@ -1284,10 +1286,12 @@ class DataFrame:
if isinstance(item, str):
# Check for alias
alias = self._get_alias()
- if alias is not None:
- return col(alias)
- else:
- return col(item)
+ if self._plan is None:
+ raise SparkConnectException("Cannot analyze on empty plan.")
+ return _to_col_with_plan_id(
+ col=alias if alias is not None else item,
+ plan_id=self._plan._plan_id,
+ )
elif isinstance(item, Column):
return self.filter(item)
elif isinstance(item, (list, tuple)):
@@ -1694,9 +1698,8 @@ def _test() -> None:
del pyspark.sql.connect.dataframe.DataFrame.repartition.__doc__
del pyspark.sql.connect.dataframe.DataFrame.repartitionByRange.__doc__
- # TODO(SPARK-41823): ambiguous column names
+ # TODO(SPARK-42367): DataFrame.drop should handle duplicated columns
del pyspark.sql.connect.dataframe.DataFrame.drop.__doc__
- del pyspark.sql.connect.dataframe.DataFrame.join.__doc__
# TODO(SPARK-41625): Support Structured Streaming
del pyspark.sql.connect.dataframe.DataFrame.isStreaming.__doc__
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 28b796496ec..571dd2b2f4b 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -339,11 +339,14 @@ class ColumnReference(Expression):
treat it as an unresolved attribute. Attributes that have the same fully
qualified name are identical"""
- def __init__(self, unparsed_identifier: str) -> None:
+ def __init__(self, unparsed_identifier: str, plan_id: Optional[int] = None) -> None:
super().__init__()
assert isinstance(unparsed_identifier, str)
self._unparsed_identifier = unparsed_identifier
+ assert plan_id is None or isinstance(plan_id, int)
+ self._plan_id = plan_id
+
def name(self) -> str:
"""Returns the qualified name of the column reference."""
return self._unparsed_identifier
@@ -352,6 +355,8 @@ class ColumnReference(Expression):
"""Returns the Proto representation of the expression."""
expr = proto.Expression()
expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier
+ if self._plan_id is not None:
+ expr.unresolved_attribute.plan_id = self._plan_id
return expr
def __repr__(self) -> str:
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index d4984b1ba67..e5305938797 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -63,6 +63,15 @@ if TYPE_CHECKING:
from pyspark.sql.connect.dataframe import DataFrame
+def _to_col_with_plan_id(col: str, plan_id: Optional[int]) -> Column:
+ if col == "*":
+ return Column(UnresolvedStar(unparsed_target=None))
+ elif col.endswith(".*"):
+ return Column(UnresolvedStar(unparsed_target=col))
+ else:
+ return Column(ColumnReference(unparsed_identifier=col, plan_id=plan_id))
+
+
def _to_col(col: "ColumnOrName") -> Column:
assert isinstance(col, (Column, str))
return col if isinstance(col, Column) else column(col)
@@ -202,12 +211,7 @@ def _options_to_col(options: Dict[str, Any]) -> Column:
def col(col: str) -> Column:
- if col == "*":
- return Column(UnresolvedStar(unparsed_target=None))
- elif col.endswith(".*"):
- return Column(UnresolvedStar(unparsed_target=col))
- else:
- return Column(ColumnReference(unparsed_identifier=col))
+ return _to_col_with_plan_id(col=col, plan_id=None)
col.__doc__ = pysparkfuncs.col.__doc__
@@ -2470,9 +2474,6 @@ def _test() -> None:
del pyspark.sql.connect.functions.timestamp_seconds.__doc__
del pyspark.sql.connect.functions.unix_timestamp.__doc__
- # TODO(SPARK-41812): Proper column names after join
- del pyspark.sql.connect.functions.count_distinct.__doc__
-
# TODO(SPARK-41843): Implement SparkSession.udf
del pyspark.sql.connect.functions.call_udf.__doc__
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index ced0e4008e1..d37201e4408 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -21,9 +21,11 @@ check_dependencies(__name__, __file__)
from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict
import functools
import json
-import pyarrow as pa
+from threading import Lock
from inspect import signature, isclass
+import pyarrow as pa
+
from pyspark.sql.types import DataType
import pyspark.sql.connect.proto as proto
@@ -40,13 +42,29 @@ class InputValidationError(Exception):
pass
-class LogicalPlan(object):
+class LogicalPlan:
+
+ _lock: Lock = Lock()
+ _nextPlanId: int = 0
INDENT = 2
def __init__(self, child: Optional["LogicalPlan"]) -> None:
self._child = child
+ plan_id: Optional[int] = None
+ with LogicalPlan._lock:
+ plan_id = LogicalPlan._nextPlanId
+ LogicalPlan._nextPlanId += 1
+
+ assert plan_id is not None
+ self._plan_id = plan_id
+
+ def _create_proto_relation(self) -> proto.Relation:
+ plan = proto.Relation()
+ plan.common.plan_id = self._plan_id
+ return plan
+
def unresolved_attr(self, colName: str) -> proto.Expression:
"""Creates an unresolved attribute from a column name."""
exp = proto.Expression()
@@ -258,7 +276,7 @@ class DataSource(LogicalPlan):
self._paths = paths
def plan(self, session: "SparkConnectClient") -> proto.Relation:
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.read.data_source.format = self._format
if self._schema is not None:
plan.read.data_source.schema = self._schema
@@ -276,7 +294,7 @@ class Read(LogicalPlan):
self.table_name = table_name
def plan(self, session: "SparkConnectClient") -> proto.Relation:
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.read.named_table.unparsed_identifier = self.table_name
return plan
@@ -306,8 +324,7 @@ class LocalRelation(LogicalPlan):
self._schema = schema
def plan(self, session: "SparkConnectClient") -> proto.Relation:
- plan = proto.Relation()
-
+ plan = self._create_proto_relation()
if self._table is not None:
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, self._table.schema) as writer:
@@ -341,7 +358,7 @@ class ShowString(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.show_string.input.CopyFrom(self._child.plan(session))
plan.show_string.num_rows = self.num_rows
plan.show_string.truncate = self.truncate
@@ -378,6 +395,8 @@ class Project(LogicalPlan):
from pyspark.sql.connect.functions import col
assert self._child is not None
+ plan = self._create_proto_relation()
+ plan.project.input.CopyFrom(self._child.plan(session))
proj_exprs = []
for c in self._columns:
@@ -386,8 +405,6 @@ class Project(LogicalPlan):
else:
proj_exprs.append(col(c).to_plan(session))
- plan = proto.Relation()
- plan.project.input.CopyFrom(self._child.plan(session))
plan.project.expressions.extend(proj_exprs)
return plan
@@ -426,7 +443,7 @@ class WithColumns(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.with_columns.input.CopyFrom(self._child.plan(session))
for i in range(0, len(self._columnNames)):
@@ -461,7 +478,7 @@ class Hint(LogicalPlan):
from pyspark.sql.connect.functions import array, lit
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.hint.input.CopyFrom(self._child.plan(session))
plan.hint.name = self._name
for param in self._parameters:
@@ -479,7 +496,7 @@ class Filter(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.filter.input.CopyFrom(self._child.plan(session))
plan.filter.condition.CopyFrom(self.filter.to_plan(session))
return plan
@@ -492,7 +509,7 @@ class Limit(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.limit.input.CopyFrom(self._child.plan(session))
plan.limit.limit = self.limit
return plan
@@ -505,7 +522,7 @@ class Tail(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.tail.input.CopyFrom(self._child.plan(session))
plan.tail.limit = self.limit
return plan
@@ -518,7 +535,7 @@ class Offset(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.offset.input.CopyFrom(self._child.plan(session))
plan.offset.offset = self.offset
return plan
@@ -537,7 +554,7 @@ class Deduplicate(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.deduplicate.input.CopyFrom(self._child.plan(session))
plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys
if self.column_names is not None:
@@ -570,7 +587,7 @@ class Sort(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.sort.input.CopyFrom(self._child.plan(session))
plan.sort.order.extend([self._convert_col(c, session) for c in self.columns])
plan.sort.is_global = self.is_global
@@ -599,7 +616,7 @@ class Drop(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.drop.input.CopyFrom(self._child.plan(session))
plan.drop.cols.extend([self._convert_to_expr(c, session) for c in self.columns])
return plan
@@ -624,7 +641,7 @@ class Sample(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.sample.input.CopyFrom(self._child.plan(session))
plan.sample.lower_bound = self.lower_bound
plan.sample.upper_bound = self.upper_bound
@@ -672,32 +689,31 @@ class Aggregate(LogicalPlan):
from pyspark.sql.connect.functions import lit
assert self._child is not None
-
- agg = proto.Relation()
-
- agg.aggregate.input.CopyFrom(self._child.plan(session))
-
- agg.aggregate.grouping_expressions.extend([c.to_plan(session) for c in self._grouping_cols])
- agg.aggregate.aggregate_expressions.extend(
+ plan = self._create_proto_relation()
+ plan.aggregate.input.CopyFrom(self._child.plan(session))
+ plan.aggregate.grouping_expressions.extend(
+ [c.to_plan(session) for c in self._grouping_cols]
+ )
+ plan.aggregate.aggregate_expressions.extend(
[c.to_plan(session) for c in self._aggregate_cols]
)
if self._group_type == "groupby":
- agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY
+ plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY
elif self._group_type == "rollup":
- agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP
+ plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP
elif self._group_type == "cube":
- agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_CUBE
+ plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_CUBE
elif self._group_type == "pivot":
- agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_PIVOT
+ plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_PIVOT
assert self._pivot_col is not None
- agg.aggregate.pivot.col.CopyFrom(self._pivot_col.to_plan(session))
+ plan.aggregate.pivot.col.CopyFrom(self._pivot_col.to_plan(session))
if self._pivot_values is not None and len(self._pivot_values) > 0:
- agg.aggregate.pivot.values.extend(
+ plan.aggregate.pivot.values.extend(
[lit(v).to_plan(session).literal for v in self._pivot_values]
)
- return agg
+ return plan
class Join(LogicalPlan):
@@ -742,23 +758,23 @@ class Join(LogicalPlan):
self.how = join_type
def plan(self, session: "SparkConnectClient") -> proto.Relation:
- rel = proto.Relation()
- rel.join.left.CopyFrom(self.left.plan(session))
- rel.join.right.CopyFrom(self.right.plan(session))
+ plan = self._create_proto_relation()
+ plan.join.left.CopyFrom(self.left.plan(session))
+ plan.join.right.CopyFrom(self.right.plan(session))
if self.on is not None:
if not isinstance(self.on, list):
if isinstance(self.on, str):
- rel.join.using_columns.append(self.on)
+ plan.join.using_columns.append(self.on)
else:
- rel.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session))
+ plan.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session))
elif len(self.on) > 0:
if isinstance(self.on[0], str):
- rel.join.using_columns.extend(cast(str, self.on))
+ plan.join.using_columns.extend(cast(str, self.on))
else:
merge_column = functools.reduce(lambda c1, c2: c1 & c2, self.on)
- rel.join.join_condition.CopyFrom(cast(Column, merge_column).to_plan(session))
- rel.join.join_type = self.how
- return rel
+ plan.join.join_condition.CopyFrom(cast(Column, merge_column).to_plan(session))
+ plan.join.join_type = self.how
+ return plan
def print(self, indent: int = 0) -> str:
i = " " * indent
@@ -800,29 +816,29 @@ class SetOperation(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- rel = proto.Relation()
+ plan = self._create_proto_relation()
if self._child is not None:
- rel.set_op.left_input.CopyFrom(self._child.plan(session))
+ plan.set_op.left_input.CopyFrom(self._child.plan(session))
if self.other is not None:
- rel.set_op.right_input.CopyFrom(self.other.plan(session))
+ plan.set_op.right_input.CopyFrom(self.other.plan(session))
if self.set_op == "union":
- rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_UNION
+ plan.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_UNION
elif self.set_op == "intersect":
- rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_INTERSECT
+ plan.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_INTERSECT
elif self.set_op == "except":
- rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_EXCEPT
+ plan.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_EXCEPT
else:
raise NotImplementedError(
"""
Unsupported set operation type: %s.
"""
- % rel.set_op.set_op_type
+ % plan.set_op.set_op_type
)
- rel.set_op.is_all = self.is_all
- rel.set_op.by_name = self.by_name
- rel.set_op.allow_missing_columns = self.allow_missing_columns
- return rel
+ plan.set_op.is_all = self.is_all
+ plan.set_op.by_name = self.by_name
+ plan.set_op.allow_missing_columns = self.allow_missing_columns
+ return plan
def print(self, indent: int = 0) -> str:
assert self._child is not None
@@ -860,12 +876,12 @@ class Repartition(LogicalPlan):
self._shuffle = shuffle
def plan(self, session: "SparkConnectClient") -> proto.Relation:
- rel = proto.Relation()
+ plan = self._create_proto_relation()
if self._child is not None:
- rel.repartition.input.CopyFrom(self._child.plan(session))
- rel.repartition.shuffle = self._shuffle
- rel.repartition.num_partitions = self._num_partitions
- return rel
+ plan.repartition.input.CopyFrom(self._child.plan(session))
+ plan.repartition.shuffle = self._shuffle
+ plan.repartition.num_partitions = self._num_partitions
+ return plan
class RepartitionByExpression(LogicalPlan):
@@ -882,7 +898,7 @@ class RepartitionByExpression(LogicalPlan):
self.columns = columns
def plan(self, session: "SparkConnectClient") -> proto.Relation:
- rel = proto.Relation()
+ plan = self._create_proto_relation()
part_exprs = []
for c in self.columns:
@@ -894,13 +910,13 @@ class RepartitionByExpression(LogicalPlan):
part_exprs.append(exp)
else:
part_exprs.append(self.unresolved_attr(c))
- rel.repartition_by_expression.partition_exprs.extend(part_exprs)
+ plan.repartition_by_expression.partition_exprs.extend(part_exprs)
if self._child is not None:
- rel.repartition_by_expression.input.CopyFrom(self._child.plan(session))
+ plan.repartition_by_expression.input.CopyFrom(self._child.plan(session))
if self.num_partitions is not None:
- rel.repartition_by_expression.num_partitions = self.num_partitions
- return rel
+ plan.repartition_by_expression.num_partitions = self.num_partitions
+ return plan
class SubqueryAlias(LogicalPlan):
@@ -911,11 +927,11 @@ class SubqueryAlias(LogicalPlan):
self._alias = alias
def plan(self, session: "SparkConnectClient") -> proto.Relation:
- rel = proto.Relation()
+ plan = self._create_proto_relation()
if self._child is not None:
- rel.subquery_alias.input.CopyFrom(self._child.plan(session))
- rel.subquery_alias.alias = self._alias
- return rel
+ plan.subquery_alias.input.CopyFrom(self._child.plan(session))
+ plan.subquery_alias.alias = self._alias
+ return plan
class SQL(LogicalPlan):
@@ -931,14 +947,14 @@ class SQL(LogicalPlan):
self._args = args
def plan(self, session: "SparkConnectClient") -> proto.Relation:
- rel = proto.Relation()
- rel.sql.query = self._query
+ plan = self._create_proto_relation()
+ plan.sql.query = self._query
if self._args is not None and len(self._args) > 0:
for k, v in self._args.items():
- rel.sql.args[k] = v
+ plan.sql.args[k] = v
- return rel
+ return plan
class Range(LogicalPlan):
@@ -956,13 +972,13 @@ class Range(LogicalPlan):
self._num_partitions = num_partitions
def plan(self, session: "SparkConnectClient") -> proto.Relation:
- rel = proto.Relation()
- rel.range.start = self._start
- rel.range.end = self._end
- rel.range.step = self._step
+ plan = self._create_proto_relation()
+ plan.range.start = self._start
+ plan.range.end = self._end
+ plan.range.step = self._step
if self._num_partitions is not None:
- rel.range.num_partitions = self._num_partitions
- return rel
+ plan.range.num_partitions = self._num_partitions
+ return plan
class ToSchema(LogicalPlan):
@@ -972,8 +988,7 @@ class ToSchema(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
-
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.to_schema.input.CopyFrom(self._child.plan(session))
plan.to_schema.schema.CopyFrom(pyspark_types_to_proto_types(self._schema))
return plan
@@ -986,8 +1001,7 @@ class WithColumnsRenamed(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
-
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.with_columns_renamed.input.CopyFrom(self._child.plan(session))
for k, v in self._colsMap.items():
plan.with_columns_renamed.rename_columns_map[k] = v
@@ -1019,8 +1033,7 @@ class Unpivot(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
-
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.unpivot.input.CopyFrom(self._child.plan(session))
plan.unpivot.ids.extend([self.col_to_expr(x, session) for x in self.ids])
if self.values is not None:
@@ -1064,7 +1077,7 @@ class NAFill(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.fill_na.input.CopyFrom(self._child.plan(session))
if self.cols is not None and len(self.cols) > 0:
plan.fill_na.cols.extend(self.cols)
@@ -1086,7 +1099,7 @@ class NADrop(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.drop_na.input.CopyFrom(self._child.plan(session))
if self.cols is not None and len(self.cols) > 0:
plan.drop_na.cols.extend(self.cols)
@@ -1122,7 +1135,7 @@ class NAReplace(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.replace.input.CopyFrom(self._child.plan(session))
if self.cols is not None and len(self.cols) > 0:
plan.replace.cols.extend(self.cols)
@@ -1150,7 +1163,7 @@ class StatSummary(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.summary.input.CopyFrom(self._child.plan(session))
plan.summary.statistics.extend(self.statistics)
return plan
@@ -1163,7 +1176,7 @@ class StatDescribe(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.describe.input.CopyFrom(self._child.plan(session))
plan.describe.cols.extend(self.cols)
return plan
@@ -1177,8 +1190,7 @@ class StatCov(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
-
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.cov.input.CopyFrom(self._child.plan(session))
plan.cov.col1 = self._col1
plan.cov.col2 = self._col2
@@ -1200,8 +1212,7 @@ class StatApproxQuantile(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
-
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.approx_quantile.input.CopyFrom(self._child.plan(session))
plan.approx_quantile.cols.extend(self._cols)
plan.approx_quantile.probabilities.extend(self._probabilities)
@@ -1217,8 +1228,7 @@ class StatCrosstab(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
-
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.crosstab.input.CopyFrom(self._child.plan(session))
plan.crosstab.col1 = self.col1
plan.crosstab.col2 = self.col2
@@ -1238,8 +1248,7 @@ class StatFreqItems(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
-
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.freq_items.input.CopyFrom(self._child.plan(session))
plan.freq_items.cols.extend(self._cols)
plan.freq_items.support = self._support
@@ -1275,8 +1284,7 @@ class StatSampleBy(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
-
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.sample_by.input.CopyFrom(self._child.plan(session))
plan.sample_by.col.CopyFrom(self._col._expr.to_plan(session))
if len(self._fractions) > 0:
@@ -1299,8 +1307,7 @@ class StatCorr(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
-
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.corr.input.CopyFrom(self._child.plan(session))
plan.corr.col1 = self._col1
plan.corr.col2 = self._col2
@@ -1315,8 +1322,7 @@ class ToDF(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
-
- plan = proto.Relation()
+ plan = self._create_proto_relation()
plan.to_df.input.CopyFrom(self._child.plan(session))
plan.to_df.column_names.extend(self._cols)
return plan
@@ -1333,8 +1339,8 @@ class CreateView(LogicalPlan):
def command(self, session: "SparkConnectClient") -> proto.Command:
assert self._child is not None
-
plan = proto.Command()
+
plan.create_dataframe_view.replace = self._replace
plan.create_dataframe_view.is_global = self._is_global
plan.create_dataframe_view.name = self._name
@@ -1358,6 +1364,7 @@ class WriteOperation(LogicalPlan):
def command(self, session: "SparkConnectClient") -> proto.Command:
assert self._child is not None
plan = proto.Command()
+
plan.write_operation.input.CopyFrom(self._child.plan(session))
if self.source is not None:
plan.write_operation.source = self.source
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 92d9e6a610a..891be5ea9ea 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xbc%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
)
@@ -300,7 +300,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_EXPRESSION._serialized_start = 105
- _EXPRESSION._serialized_end = 4859
+ _EXPRESSION._serialized_end = 4901
_EXPRESSION_WINDOW._serialized_start = 1475
_EXPRESSION_WINDOW._serialized_end = 2258
_EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765
@@ -324,29 +324,29 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3599
_EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3697
_EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3715
- _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3785
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3788
- _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3992
- _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3994
- _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4044
- _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4046
- _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4128
- _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4130
- _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4174
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4177
- _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4309
- _EXPRESSION_UPDATEFIELDS._serialized_start = 4312
- _EXPRESSION_UPDATEFIELDS._serialized_end = 4499
- _EXPRESSION_ALIAS._serialized_start = 4501
- _EXPRESSION_ALIAS._serialized_end = 4621
- _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4624
- _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4782
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4784
- _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4846
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4862
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5173
- _PYTHONUDF._serialized_start = 5176
- _PYTHONUDF._serialized_end = 5306
- _SCALARSCALAUDF._serialized_start = 5309
- _SCALARSCALAUDF._serialized_end = 5493
+ _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3827
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3830
+ _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4034
+ _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4036
+ _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4086
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4088
+ _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4170
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4172
+ _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4216
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4219
+ _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4351
+ _EXPRESSION_UPDATEFIELDS._serialized_start = 4354
+ _EXPRESSION_UPDATEFIELDS._serialized_end = 4541
+ _EXPRESSION_ALIAS._serialized_start = 4543
+ _EXPRESSION_ALIAS._serialized_end = 4663
+ _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4666
+ _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4824
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4826
+ _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4888
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4904
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5215
+ _PYTHONUDF._serialized_start = 5218
+ _PYTHONUDF._serialized_end = 5348
+ _SCALARSCALAUDF._serialized_start = 5351
+ _SCALARSCALAUDF._serialized_end = 5535
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 934e0016c90..88b1fd8ef7e 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -613,19 +613,37 @@ class Expression(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
UNPARSED_IDENTIFIER_FIELD_NUMBER: builtins.int
+ PLAN_ID_FIELD_NUMBER: builtins.int
unparsed_identifier: builtins.str
"""(Required) An identifier that will be parsed by Catalyst parser. This should follow the
Spark SQL identifier syntax.
"""
+ plan_id: builtins.int
+ """(Optional) The id of corresponding connect plan."""
def __init__(
self,
*,
unparsed_identifier: builtins.str = ...,
+ plan_id: builtins.int | None = ...,
) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal["_plan_id", b"_plan_id", "plan_id", b"plan_id"],
+ ) -> builtins.bool: ...
def ClearField(
self,
- field_name: typing_extensions.Literal["unparsed_identifier", b"unparsed_identifier"],
+ field_name: typing_extensions.Literal[
+ "_plan_id",
+ b"_plan_id",
+ "plan_id",
+ b"plan_id",
+ "unparsed_identifier",
+ b"unparsed_identifier",
+ ],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"]
+ ) -> typing_extensions.Literal["plan_id"] | None: ...
class UnresolvedFunction(google.protobuf.message.Message):
"""An unresolved function is not explicitly bound to one explicit function, but the function
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index ece5920953e..057b96a8da9 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -639,103 +639,103 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_UNKNOWN._serialized_start = 2464
_UNKNOWN._serialized_end = 2473
_RELATIONCOMMON._serialized_start = 2475
- _RELATIONCOMMON._serialized_end = 2524
- _SQL._serialized_start = 2527
- _SQL._serialized_end = 2661
- _SQL_ARGSENTRY._serialized_start = 2606
- _SQL_ARGSENTRY._serialized_end = 2661
- _READ._serialized_start = 2664
- _READ._serialized_end = 3112
- _READ_NAMEDTABLE._serialized_start = 2806
- _READ_NAMEDTABLE._serialized_end = 2867
- _READ_DATASOURCE._serialized_start = 2870
- _READ_DATASOURCE._serialized_end = 3099
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3030
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3088
- _PROJECT._serialized_start = 3114
- _PROJECT._serialized_end = 3231
- _FILTER._serialized_start = 3233
- _FILTER._serialized_end = 3345
- _JOIN._serialized_start = 3348
- _JOIN._serialized_end = 3819
- _JOIN_JOINTYPE._serialized_start = 3611
- _JOIN_JOINTYPE._serialized_end = 3819
- _SETOPERATION._serialized_start = 3822
- _SETOPERATION._serialized_end = 4301
- _SETOPERATION_SETOPTYPE._serialized_start = 4138
- _SETOPERATION_SETOPTYPE._serialized_end = 4252
- _LIMIT._serialized_start = 4303
- _LIMIT._serialized_end = 4379
- _OFFSET._serialized_start = 4381
- _OFFSET._serialized_end = 4460
- _TAIL._serialized_start = 4462
- _TAIL._serialized_end = 4537
- _AGGREGATE._serialized_start = 4540
- _AGGREGATE._serialized_end = 5122
- _AGGREGATE_PIVOT._serialized_start = 4879
- _AGGREGATE_PIVOT._serialized_end = 4990
- _AGGREGATE_GROUPTYPE._serialized_start = 4993
- _AGGREGATE_GROUPTYPE._serialized_end = 5122
- _SORT._serialized_start = 5125
- _SORT._serialized_end = 5285
- _DROP._serialized_start = 5287
- _DROP._serialized_end = 5387
- _DEDUPLICATE._serialized_start = 5390
- _DEDUPLICATE._serialized_end = 5561
- _LOCALRELATION._serialized_start = 5563
- _LOCALRELATION._serialized_end = 5652
- _SAMPLE._serialized_start = 5655
- _SAMPLE._serialized_end = 5928
- _RANGE._serialized_start = 5931
- _RANGE._serialized_end = 6076
- _SUBQUERYALIAS._serialized_start = 6078
- _SUBQUERYALIAS._serialized_end = 6192
- _REPARTITION._serialized_start = 6195
- _REPARTITION._serialized_end = 6337
- _SHOWSTRING._serialized_start = 6340
- _SHOWSTRING._serialized_end = 6482
- _STATSUMMARY._serialized_start = 6484
- _STATSUMMARY._serialized_end = 6576
- _STATDESCRIBE._serialized_start = 6578
- _STATDESCRIBE._serialized_end = 6659
- _STATCROSSTAB._serialized_start = 6661
- _STATCROSSTAB._serialized_end = 6762
- _STATCOV._serialized_start = 6764
- _STATCOV._serialized_end = 6860
- _STATCORR._serialized_start = 6863
- _STATCORR._serialized_end = 7000
- _STATAPPROXQUANTILE._serialized_start = 7003
- _STATAPPROXQUANTILE._serialized_end = 7167
- _STATFREQITEMS._serialized_start = 7169
- _STATFREQITEMS._serialized_end = 7294
- _STATSAMPLEBY._serialized_start = 7297
- _STATSAMPLEBY._serialized_end = 7606
- _STATSAMPLEBY_FRACTION._serialized_start = 7498
- _STATSAMPLEBY_FRACTION._serialized_end = 7597
- _NAFILL._serialized_start = 7609
- _NAFILL._serialized_end = 7743
- _NADROP._serialized_start = 7746
- _NADROP._serialized_end = 7880
- _NAREPLACE._serialized_start = 7883
- _NAREPLACE._serialized_end = 8179
- _NAREPLACE_REPLACEMENT._serialized_start = 8038
- _NAREPLACE_REPLACEMENT._serialized_end = 8179
- _TODF._serialized_start = 8181
- _TODF._serialized_end = 8269
- _WITHCOLUMNSRENAMED._serialized_start = 8272
- _WITHCOLUMNSRENAMED._serialized_end = 8511
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8444
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8511
- _WITHCOLUMNS._serialized_start = 8513
- _WITHCOLUMNS._serialized_end = 8632
- _HINT._serialized_start = 8635
- _HINT._serialized_end = 8767
- _UNPIVOT._serialized_start = 8770
- _UNPIVOT._serialized_end = 9097
- _UNPIVOT_VALUES._serialized_start = 9027
- _UNPIVOT_VALUES._serialized_end = 9086
- _TOSCHEMA._serialized_start = 9099
- _TOSCHEMA._serialized_end = 9205
- _REPARTITIONBYEXPRESSION._serialized_start = 9208
- _REPARTITIONBYEXPRESSION._serialized_end = 9411
+ _RELATIONCOMMON._serialized_end = 2566
+ _SQL._serialized_start = 2569
+ _SQL._serialized_end = 2703
+ _SQL_ARGSENTRY._serialized_start = 2648
+ _SQL_ARGSENTRY._serialized_end = 2703
+ _READ._serialized_start = 2706
+ _READ._serialized_end = 3154
+ _READ_NAMEDTABLE._serialized_start = 2848
+ _READ_NAMEDTABLE._serialized_end = 2909
+ _READ_DATASOURCE._serialized_start = 2912
+ _READ_DATASOURCE._serialized_end = 3141
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3072
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3130
+ _PROJECT._serialized_start = 3156
+ _PROJECT._serialized_end = 3273
+ _FILTER._serialized_start = 3275
+ _FILTER._serialized_end = 3387
+ _JOIN._serialized_start = 3390
+ _JOIN._serialized_end = 3861
+ _JOIN_JOINTYPE._serialized_start = 3653
+ _JOIN_JOINTYPE._serialized_end = 3861
+ _SETOPERATION._serialized_start = 3864
+ _SETOPERATION._serialized_end = 4343
+ _SETOPERATION_SETOPTYPE._serialized_start = 4180
+ _SETOPERATION_SETOPTYPE._serialized_end = 4294
+ _LIMIT._serialized_start = 4345
+ _LIMIT._serialized_end = 4421
+ _OFFSET._serialized_start = 4423
+ _OFFSET._serialized_end = 4502
+ _TAIL._serialized_start = 4504
+ _TAIL._serialized_end = 4579
+ _AGGREGATE._serialized_start = 4582
+ _AGGREGATE._serialized_end = 5164
+ _AGGREGATE_PIVOT._serialized_start = 4921
+ _AGGREGATE_PIVOT._serialized_end = 5032
+ _AGGREGATE_GROUPTYPE._serialized_start = 5035
+ _AGGREGATE_GROUPTYPE._serialized_end = 5164
+ _SORT._serialized_start = 5167
+ _SORT._serialized_end = 5327
+ _DROP._serialized_start = 5329
+ _DROP._serialized_end = 5429
+ _DEDUPLICATE._serialized_start = 5432
+ _DEDUPLICATE._serialized_end = 5603
+ _LOCALRELATION._serialized_start = 5605
+ _LOCALRELATION._serialized_end = 5694
+ _SAMPLE._serialized_start = 5697
+ _SAMPLE._serialized_end = 5970
+ _RANGE._serialized_start = 5973
+ _RANGE._serialized_end = 6118
+ _SUBQUERYALIAS._serialized_start = 6120
+ _SUBQUERYALIAS._serialized_end = 6234
+ _REPARTITION._serialized_start = 6237
+ _REPARTITION._serialized_end = 6379
+ _SHOWSTRING._serialized_start = 6382
+ _SHOWSTRING._serialized_end = 6524
+ _STATSUMMARY._serialized_start = 6526
+ _STATSUMMARY._serialized_end = 6618
+ _STATDESCRIBE._serialized_start = 6620
+ _STATDESCRIBE._serialized_end = 6701
+ _STATCROSSTAB._serialized_start = 6703
+ _STATCROSSTAB._serialized_end = 6804
+ _STATCOV._serialized_start = 6806
+ _STATCOV._serialized_end = 6902
+ _STATCORR._serialized_start = 6905
+ _STATCORR._serialized_end = 7042
+ _STATAPPROXQUANTILE._serialized_start = 7045
+ _STATAPPROXQUANTILE._serialized_end = 7209
+ _STATFREQITEMS._serialized_start = 7211
+ _STATFREQITEMS._serialized_end = 7336
+ _STATSAMPLEBY._serialized_start = 7339
+ _STATSAMPLEBY._serialized_end = 7648
+ _STATSAMPLEBY_FRACTION._serialized_start = 7540
+ _STATSAMPLEBY_FRACTION._serialized_end = 7639
+ _NAFILL._serialized_start = 7651
+ _NAFILL._serialized_end = 7785
+ _NADROP._serialized_start = 7788
+ _NADROP._serialized_end = 7922
+ _NAREPLACE._serialized_start = 7925
+ _NAREPLACE._serialized_end = 8221
+ _NAREPLACE_REPLACEMENT._serialized_start = 8080
+ _NAREPLACE_REPLACEMENT._serialized_end = 8221
+ _TODF._serialized_start = 8223
+ _TODF._serialized_end = 8311
+ _WITHCOLUMNSRENAMED._serialized_start = 8314
+ _WITHCOLUMNSRENAMED._serialized_end = 8553
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8486
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8553
+ _WITHCOLUMNS._serialized_start = 8555
+ _WITHCOLUMNS._serialized_end = 8674
+ _HINT._serialized_start = 8677
+ _HINT._serialized_end = 8809
+ _UNPIVOT._serialized_start = 8812
+ _UNPIVOT._serialized_end = 9139
+ _UNPIVOT_VALUES._serialized_start = 9069
+ _UNPIVOT_VALUES._serialized_end = 9128
+ _TOSCHEMA._serialized_start = 9141
+ _TOSCHEMA._serialized_end = 9247
+ _REPARTITIONBYEXPRESSION._serialized_start = 9250
+ _REPARTITIONBYEXPRESSION._serialized_end = 9453
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 41962ee4062..b7cef7b299d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -478,16 +478,29 @@ class RelationCommon(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
SOURCE_INFO_FIELD_NUMBER: builtins.int
+ PLAN_ID_FIELD_NUMBER: builtins.int
source_info: builtins.str
"""(Required) Shared relation metadata."""
+ plan_id: builtins.int
+ """(Optional) A per-client globally unique id for a given connect plan."""
def __init__(
self,
*,
source_info: builtins.str = ...,
+ plan_id: builtins.int | None = ...,
) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["_plan_id", b"_plan_id", "plan_id", b"plan_id"]
+ ) -> builtins.bool: ...
def ClearField(
- self, field_name: typing_extensions.Literal["source_info", b"source_info"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_plan_id", b"_plan_id", "plan_id", b"plan_id", "source_info", b"source_info"
+ ],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"]
+ ) -> typing_extensions.Literal["plan_id"] | None: ...
global___RelationCommon = RelationCommon
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index a723163cbe8..9e9341c9a2a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -356,6 +356,68 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
)
self.assert_eq(joined_plan3.toPandas(), joined_plan4.toPandas())
+ def test_join_ambiguous_cols(self):
+ # SPARK-41812: test join with ambiguous columns
+ data1 = [Row(id=1, value="foo"), Row(id=2, value=None)]
+ cdf1 = self.connect.createDataFrame(data1)
+ sdf1 = self.spark.createDataFrame(data1)
+
+ data2 = [Row(value="bar"), Row(value=None), Row(value="foo")]
+ cdf2 = self.connect.createDataFrame(data2)
+ sdf2 = self.spark.createDataFrame(data2)
+
+ cdf3 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"])
+ sdf3 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"])
+
+ self.assertEqual(cdf3.schema, sdf3.schema)
+ self.assertEqual(cdf3.collect(), sdf3.collect())
+
+ cdf4 = cdf1.join(cdf2, cdf1["value"].eqNullSafe(cdf2["value"]))
+ sdf4 = sdf1.join(sdf2, sdf1["value"].eqNullSafe(sdf2["value"]))
+
+ self.assertEqual(cdf4.schema, sdf4.schema)
+ self.assertEqual(cdf4.collect(), sdf4.collect())
+
+ cdf5 = cdf1.join(
+ cdf2, (cdf1["value"] == cdf2["value"]) & (cdf1["value"].eqNullSafe(cdf2["value"]))
+ )
+ sdf5 = sdf1.join(
+ sdf2, (sdf1["value"] == sdf2["value"]) & (sdf1["value"].eqNullSafe(sdf2["value"]))
+ )
+
+ self.assertEqual(cdf5.schema, sdf5.schema)
+ self.assertEqual(cdf5.collect(), sdf5.collect())
+
+ cdf6 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf1.value)
+ sdf6 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf1.value)
+
+ self.assertEqual(cdf6.schema, sdf6.schema)
+ self.assertEqual(cdf6.collect(), sdf6.collect())
+
+ cdf7 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf2.value)
+ sdf7 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf2.value)
+
+ self.assertEqual(cdf7.schema, sdf7.schema)
+ self.assertEqual(cdf7.collect(), sdf7.collect())
+
+ def test_invalid_column(self):
+ # SPARK-41812: fail df1.select(df2.col)
+ data1 = [Row(a=1, b=2, c=3)]
+ cdf1 = self.connect.createDataFrame(data1)
+
+ data2 = [Row(a=2, b=0)]
+ cdf2 = self.connect.createDataFrame(data2)
+
+ with self.assertRaises(AnalysisException):
+ cdf1.select(cdf2.a).schema
+
+ with self.assertRaises(AnalysisException):
+ cdf2.withColumn("x", cdf1.a + 1).schema
+
+ with self.assertRaisesRegex(AnalysisException, "attribute.*missing"):
+ cdf3 = cdf1.select(cdf1.a)
+ cdf3.select(cdf1.b).schema
+
def test_collect(self):
cdf = self.connect.read.table(self.tbl_name)
sdf = self.spark.read.table(self.tbl_name)
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 1892e64f8f9..a5f691d0bef 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -85,7 +85,18 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
right_input = self.connect.readTable(table_name=self.tbl_name)
crossJoin_plan = left_input.crossJoin(other=right_input)._plan.to_proto(self.connect)
join_plan = left_input.join(other=right_input, how="cross")._plan.to_proto(self.connect)
- self.assertEqual(crossJoin_plan, join_plan)
+ self.assertEqual(
+ crossJoin_plan.root.join.left.read.named_table,
+ join_plan.root.join.left.read.named_table,
+ )
+ self.assertEqual(
+ crossJoin_plan.root.join.right.read.named_table,
+ join_plan.root.join.right.read.named_table,
+ )
+ self.assertEqual(
+ crossJoin_plan.root.join.join_type,
+ join_plan.root.join.join_type,
+ )
def test_filter(self):
df = self.connect.readTable(table_name=self.tbl_name)
@@ -732,7 +743,12 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
self.assertIsNotNone(cp1)
self.assertEqual(cp1, cp2)
- self.assertEqual(cp2, cp3)
+ self.assertEqual(
+ cp2.unresolved_attribute.unparsed_identifier,
+ cp3.unresolved_attribute.unparsed_identifier,
+ )
+ self.assertTrue(cp2.unresolved_attribute.HasField("plan_id"))
+ self.assertFalse(cp3.unresolved_attribute.HasField("plan_id"))
def test_null_literal(self):
null_lit = lit(None)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9a2648a79a5..eff8c114a97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3307,12 +3307,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
_.containsPattern(NATURAL_LIKE_JOIN), ruleId) {
case j @ Join(left, right, UsingJoin(joinType, usingCols), _, hint)
if left.resolved && right.resolved && j.duplicateResolved =>
- commonNaturalJoinProcessing(left, right, joinType, usingCols, None, hint)
+ commonNaturalJoinProcessing(left, right, joinType, usingCols, None, hint,
+ j.getTagValue(LogicalPlan.PLAN_ID_TAG))
case j @ Join(left, right, NaturalJoin(joinType), condition, hint)
if j.resolvedExceptNatural =>
// find common column names from both sides
val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
- commonNaturalJoinProcessing(left, right, joinType, joinNames, condition, hint)
+ commonNaturalJoinProcessing(left, right, joinType, joinNames, condition, hint,
+ j.getTagValue(LogicalPlan.PLAN_ID_TAG))
}
}
@@ -3442,7 +3444,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
joinType: JoinType,
joinNames: Seq[String],
condition: Option[Expression],
- hint: JoinHint): LogicalPlan = {
+ hint: JoinHint,
+ planId: Option[Long] = None): LogicalPlan = {
import org.apache.spark.sql.catalyst.util._
val leftKeys = joinNames.map { keyName =>
@@ -3483,9 +3486,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case _ =>
throw QueryExecutionErrors.unsupportedNaturalJoinTypeError(joinType)
}
+
+ val newJoin = Join(left, right, joinType, newCondition, hint)
+ // retain the plan id used in Spark Connect
+ planId.foreach(newJoin.setTagValue(LogicalPlan.PLAN_ID_TAG, _))
+
// use Project to hide duplicated common keys
// propagate hidden columns from nested USING/NATURAL JOINs
- val project = Project(projectList, Join(left, right, joinType, newCondition, hint))
+ val project = Project(projectList, newJoin)
project.setTagValue(
Project.hiddenOutputTag,
hiddenList.map(_.markAsQualifiedAccessOnly()) ++
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 0985fe6852c..ba550bce791 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -357,8 +357,21 @@ trait ColumnResolutionHelper extends Logging {
e: Expression,
q: LogicalPlan,
allowOuter: Boolean = false): Expression = {
+ val newE = if (e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) {
+ // If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and
+ // expression are from Spark Connect, and need to be resolved in this way:
+ // 1, extract the attached plan id from the expression (UnresolvedAttribute only for now);
+ // 2, top-down traverse the query plan to find the plan node that matches the plan id;
+ // 3, if the matching node can not be found, fail the analysis due to illegal references;
+ // 4, resolve the expression with the matching node; if any error occurs here, fall back to
+ // the old code path;
+ resolveExpressionByPlanId(e, q)
+ } else {
+ e
+ }
+
resolveExpression(
- e,
+ newE,
resolveColumnByName = nameParts => {
q.resolveChildren(nameParts, conf.resolver)
},
@@ -369,4 +382,47 @@ trait ColumnResolutionHelper extends Logging {
throws = true,
allowOuter = allowOuter)
}
+
+ private def resolveExpressionByPlanId(
+ e: Expression,
+ q: LogicalPlan): Expression = {
+ if (!e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) {
+ return e
+ }
+
+ e match {
+ case u: UnresolvedAttribute =>
+ resolveUnresolvedAttributeByPlanId(u, q).getOrElse(u)
+ case _ =>
+ e.mapChildren(c => resolveExpressionByPlanId(c, q))
+ }
+ }
+
+ private def resolveUnresolvedAttributeByPlanId(
+ u: UnresolvedAttribute,
+ q: LogicalPlan): Option[NamedExpression] = {
+ val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
+ if (planIdOpt.isEmpty) return None
+ val planId = planIdOpt.get
+ logDebug(s"Extract plan_id $planId from $u")
+
+ val planOpt = q.find(_.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(planId))
+ if (planOpt.isEmpty) {
+ // For example:
+ // df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)])
+ // df2 = spark.createDataFrame([Row(a = 1, b = 2)])
+ // df1.select(df2.a) <- illegal reference df2.a
+ throw new AnalysisException(s"When resolving $u, " +
+ s"fail to find subplan with plan_id=$planId in $q")
+ }
+ val plan = planOpt.get
+
+ try {
+ plan.resolve(u.nameParts, conf.resolver)
+ } catch {
+ case e: AnalysisException =>
+ logDebug(s"Fail to resolve $u with $plan due to $e")
+ None
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 5a7dcff3667..36187bb2d55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{AliasAwareQueryOutputOrdering, QueryPlan}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats
-import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, UnaryLike}
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, UnaryLike}
import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.{DataType, StructType}
@@ -157,6 +157,18 @@ abstract class LogicalPlan
}
}
+object LogicalPlan {
+ // A dedicated tag for Spark Connect.
+ // If an expression (only UnresolvedAttribute is supported for now) is attached with this tag,
+ // the analyzer will:
+ // 1, extract the plan id;
+ // 2, top-down traverse the query plan to find the node that is attached with the same tag,
+ // and fail the whole analysis if it can not be found;
+ // 3, resolve this expression with the matching node. If any error occurs, the analyzer falls
+ // back to the old code path.
+ private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id")
+}
+
/**
* A logical plan node with no children.
*/