You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/02/13 08:22:42 UTC

[spark] branch master updated: [SPARK-41812][SPARK-41823][CONNECT][SQL][PYTHON] Resolve ambiguous columns issue in `Join`

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 167bbca49c1 [SPARK-41812][SPARK-41823][CONNECT][SQL][PYTHON] Resolve ambiguous columns issue in `Join`
167bbca49c1 is described below

commit 167bbca49c1c12ccd349d4330862c136b38d4522
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Mon Feb 13 16:22:06 2023 +0800

    [SPARK-41812][SPARK-41823][CONNECT][SQL][PYTHON] Resolve ambiguous columns issue in `Join`
    
    ### What changes were proposed in this pull request?
    
    In Python Client
    - generate `plan_id` for each proto plan (It's up to the Client to guarantee the uniqueness);
    - attach `plan_id` to the column created by `DataFrame[col_name]` or `DataFrame.col_name`;
    - Note that `F.col(col_name)` doesn't have `plan_id`;
    
    In Connect Planner:
    - attach `plan_id` to `UnresolvedAttribute`s and `LogicalPlan`s via `TreeNodeTag`
    
    In Analyzer:
    - for an `UnresolvedAttribute` with `plan_id`, search for the matching node in the plan, and resolve it with the found node if possible
    
    **Out of scope:**
    
    - resolve `self-join`
    - add a `DetectAmbiguousSelfJoin`-like rule for detection
    
    ### Why are the changes needed?
    This fixes a bug. Before this PR:
    ```
    df1.join(df2, df1["value"] == df2["value"])  <- fails because `value` cannot be resolved
    df1.join(df2, df1["value"] == df2["value"]).select(df1.value) <- fails because `value` cannot be resolved
    df1.select(df2.value)    <- should fail, but runs as `df1.select(df1.value)` and returns incorrect results
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    added tests, enabled tests
    
    Closes #39925 from zhengruifeng/connect_plan_id.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto  |   3 +
 .../main/protobuf/spark/connect/relations.proto    |   3 +
 .../sql/connect/planner/SparkConnectPlanner.scala  |  22 ++-
 python/pyspark/sql/column.py                       |   5 +-
 python/pyspark/sql/connect/dataframe.py            |  15 +-
 python/pyspark/sql/connect/expressions.py          |   7 +-
 python/pyspark/sql/connect/functions.py            |  19 +-
 python/pyspark/sql/connect/plan.py                 | 213 +++++++++++----------
 .../pyspark/sql/connect/proto/expressions_pb2.py   |  54 +++---
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  |  20 +-
 python/pyspark/sql/connect/proto/relations_pb2.py  | 200 +++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  15 +-
 .../sql/tests/connect/test_connect_basic.py        |  62 ++++++
 .../pyspark/sql/tests/connect/test_connect_plan.py |  20 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  16 +-
 .../catalyst/analysis/ColumnResolutionHelper.scala |  58 +++++-
 .../sql/catalyst/plans/logical/LogicalPlan.scala   |  14 +-
 17 files changed, 481 insertions(+), 265 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 8682e1ee27b..1929d9cdca3 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -197,6 +197,9 @@ message Expression {
     // (Required) An identifier that will be parsed by Catalyst parser. This should follow the
     // Spark SQL identifier syntax.
     string unparsed_identifier = 1;
+
+    // (Optional) The id of corresponding connect plan.
+    optional int64 plan_id = 2;
   }
 
   // An unresolved function is not explicitly bound to one explicit function, but the function
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index ea1216957d8..29fffd65c75 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -93,6 +93,9 @@ message Unknown {}
 message RelationCommon {
   // (Required) Shared relation metadata.
   string source_info = 1;
+
+  // (Optional) A per-client globally unique id for a given connect plan.
+  optional int64 plan_id = 2;
 }
 
 // Relation that uses a SQL query to generate the output.
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 740d6b85964..53d494cdcb7 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -66,7 +66,7 @@ class SparkConnectPlanner(val session: SparkSession) {
 
   // The root of the query plan is a relation and we apply the transformations to it.
   def transformRelation(rel: proto.Relation): LogicalPlan = {
-    rel.getRelTypeCase match {
+    val plan = rel.getRelTypeCase match {
       // DataFrame API
       case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
       case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
@@ -124,6 +124,11 @@ class SparkConnectPlanner(val session: SparkSession) {
         transformRelationPlugin(rel.getExtension)
       case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
     }
+
+    if (rel.hasCommon && rel.getCommon.hasPlanId) {
+      plan.setTagValue(LogicalPlan.PLAN_ID_TAG, rel.getCommon.getPlanId)
+    }
+    plan
   }
 
   private def transformRelationPlugin(extension: ProtoAny): LogicalPlan = {
@@ -702,10 +707,6 @@ class SparkConnectPlanner(val session: SparkSession) {
     logical.Project(projectList = projection, child = baseRel)
   }
 
-  private def transformUnresolvedExpression(exp: proto.Expression): UnresolvedAttribute = {
-    UnresolvedAttribute.quotedString(exp.getUnresolvedAttribute.getUnparsedIdentifier)
-  }
-
   /**
    * Transforms an input protobuf expression into the Catalyst expression. This is usually not
    * called directly. Typically the planner will traverse the expressions automatically, only
@@ -720,7 +721,7 @@ class SparkConnectPlanner(val session: SparkSession) {
     exp.getExprTypeCase match {
       case proto.Expression.ExprTypeCase.LITERAL => transformLiteral(exp.getLiteral)
       case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE =>
-        transformUnresolvedExpression(exp)
+        transformUnresolvedAttribute(exp.getUnresolvedAttribute)
       case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION =>
         transformUnregisteredFunction(exp.getUnresolvedFunction)
           .getOrElse(transformUnresolvedFunction(exp.getUnresolvedFunction))
@@ -758,6 +759,15 @@ class SparkConnectPlanner(val session: SparkSession) {
     case expr => UnresolvedAlias(expr)
   }
 
+  private def transformUnresolvedAttribute(
+      attr: proto.Expression.UnresolvedAttribute): UnresolvedAttribute = {
+    val expr = UnresolvedAttribute.quotedString(attr.getUnparsedIdentifier)
+    if (attr.hasPlanId) {
+      expr.setTagValue(LogicalPlan.PLAN_ID_TAG, attr.getPlanId)
+    }
+    expr
+  }
+
   private def transformExpressionPlugin(extension: ProtoAny): Expression = {
     SparkConnectPluginRegistry.expressionRegistry
       // Lazily traverse the collection.
diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py
index a790f191110..0b5f94cfaaa 100644
--- a/python/pyspark/sql/column.py
+++ b/python/pyspark/sql/column.py
@@ -279,7 +279,6 @@ class Column:
     __ge__ = _bin_op("geq")
     __gt__ = _bin_op("gt")
 
-    # TODO(SPARK-41812): DataFrame.join: ambiguous column
     _eqNullSafe_doc = """
     Equality test that is safe for null values.
 
@@ -315,9 +314,9 @@ class Column:
     ...     Row(value = 'bar'),
     ...     Row(value = None)
     ... ])
-    >>> df1.join(df2, df1["value"] == df2["value"]).count()  # doctest: +SKIP
+    >>> df1.join(df2, df1["value"] == df2["value"]).count()
     0
-    >>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count()  # doctest: +SKIP
+    >>> df1.join(df2, df1["value"].eqNullSafe(df2["value"])).count()
     1
     >>> df2 = spark.createDataFrame([
     ...     Row(id=1, value=float('NaN')),
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 95e39f93dc0..667295e8667 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -50,12 +50,14 @@ from pyspark.sql.dataframe import (
 )
 
 from pyspark.errors import PySparkTypeError
+from pyspark.errors.exceptions.connect import SparkConnectException
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.connect.group import GroupedData
 from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import UnresolvedRegex
 from pyspark.sql.connect.functions import (
+    _to_col_with_plan_id,
     _to_col,
     _invoke_function,
     col,
@@ -1284,10 +1286,12 @@ class DataFrame:
         if isinstance(item, str):
             # Check for alias
             alias = self._get_alias()
-            if alias is not None:
-                return col(alias)
-            else:
-                return col(item)
+            if self._plan is None:
+                raise SparkConnectException("Cannot analyze on empty plan.")
+            return _to_col_with_plan_id(
+                col=alias if alias is not None else item,
+                plan_id=self._plan._plan_id,
+            )
         elif isinstance(item, Column):
             return self.filter(item)
         elif isinstance(item, (list, tuple)):
@@ -1694,9 +1698,8 @@ def _test() -> None:
     del pyspark.sql.connect.dataframe.DataFrame.repartition.__doc__
     del pyspark.sql.connect.dataframe.DataFrame.repartitionByRange.__doc__
 
-    # TODO(SPARK-41823): ambiguous column names
+    # TODO(SPARK-42367): DataFrame.drop should handle duplicated columns
     del pyspark.sql.connect.dataframe.DataFrame.drop.__doc__
-    del pyspark.sql.connect.dataframe.DataFrame.join.__doc__
 
     # TODO(SPARK-41625): Support Structured Streaming
     del pyspark.sql.connect.dataframe.DataFrame.isStreaming.__doc__
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 28b796496ec..571dd2b2f4b 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -339,11 +339,14 @@ class ColumnReference(Expression):
     treat it as an unresolved attribute. Attributes that have the same fully
     qualified name are identical"""
 
-    def __init__(self, unparsed_identifier: str) -> None:
+    def __init__(self, unparsed_identifier: str, plan_id: Optional[int] = None) -> None:
         super().__init__()
         assert isinstance(unparsed_identifier, str)
         self._unparsed_identifier = unparsed_identifier
 
+        assert plan_id is None or isinstance(plan_id, int)
+        self._plan_id = plan_id
+
     def name(self) -> str:
         """Returns the qualified name of the column reference."""
         return self._unparsed_identifier
@@ -352,6 +355,8 @@ class ColumnReference(Expression):
         """Returns the Proto representation of the expression."""
         expr = proto.Expression()
         expr.unresolved_attribute.unparsed_identifier = self._unparsed_identifier
+        if self._plan_id is not None:
+            expr.unresolved_attribute.plan_id = self._plan_id
         return expr
 
     def __repr__(self) -> str:
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index d4984b1ba67..e5305938797 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -63,6 +63,15 @@ if TYPE_CHECKING:
     from pyspark.sql.connect.dataframe import DataFrame
 
 
+def _to_col_with_plan_id(col: str, plan_id: Optional[int]) -> Column:
+    if col == "*":
+        return Column(UnresolvedStar(unparsed_target=None))
+    elif col.endswith(".*"):
+        return Column(UnresolvedStar(unparsed_target=col))
+    else:
+        return Column(ColumnReference(unparsed_identifier=col, plan_id=plan_id))
+
+
 def _to_col(col: "ColumnOrName") -> Column:
     assert isinstance(col, (Column, str))
     return col if isinstance(col, Column) else column(col)
@@ -202,12 +211,7 @@ def _options_to_col(options: Dict[str, Any]) -> Column:
 
 
 def col(col: str) -> Column:
-    if col == "*":
-        return Column(UnresolvedStar(unparsed_target=None))
-    elif col.endswith(".*"):
-        return Column(UnresolvedStar(unparsed_target=col))
-    else:
-        return Column(ColumnReference(unparsed_identifier=col))
+    return _to_col_with_plan_id(col=col, plan_id=None)
 
 
 col.__doc__ = pysparkfuncs.col.__doc__
@@ -2470,9 +2474,6 @@ def _test() -> None:
     del pyspark.sql.connect.functions.timestamp_seconds.__doc__
     del pyspark.sql.connect.functions.unix_timestamp.__doc__
 
-    # TODO(SPARK-41812): Proper column names after join
-    del pyspark.sql.connect.functions.count_distinct.__doc__
-
     # TODO(SPARK-41843): Implement SparkSession.udf
     del pyspark.sql.connect.functions.call_udf.__doc__
 
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index ced0e4008e1..d37201e4408 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -21,9 +21,11 @@ check_dependencies(__name__, __file__)
 from typing import Any, List, Optional, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict
 import functools
 import json
-import pyarrow as pa
+from threading import Lock
 from inspect import signature, isclass
 
+import pyarrow as pa
+
 from pyspark.sql.types import DataType
 
 import pyspark.sql.connect.proto as proto
@@ -40,13 +42,29 @@ class InputValidationError(Exception):
     pass
 
 
-class LogicalPlan(object):
+class LogicalPlan:
+
+    _lock: Lock = Lock()
+    _nextPlanId: int = 0
 
     INDENT = 2
 
     def __init__(self, child: Optional["LogicalPlan"]) -> None:
         self._child = child
 
+        plan_id: Optional[int] = None
+        with LogicalPlan._lock:
+            plan_id = LogicalPlan._nextPlanId
+            LogicalPlan._nextPlanId += 1
+
+        assert plan_id is not None
+        self._plan_id = plan_id
+
+    def _create_proto_relation(self) -> proto.Relation:
+        plan = proto.Relation()
+        plan.common.plan_id = self._plan_id
+        return plan
+
     def unresolved_attr(self, colName: str) -> proto.Expression:
         """Creates an unresolved attribute from a column name."""
         exp = proto.Expression()
@@ -258,7 +276,7 @@ class DataSource(LogicalPlan):
         self._paths = paths
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.read.data_source.format = self._format
         if self._schema is not None:
             plan.read.data_source.schema = self._schema
@@ -276,7 +294,7 @@ class Read(LogicalPlan):
         self.table_name = table_name
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.read.named_table.unparsed_identifier = self.table_name
         return plan
 
@@ -306,8 +324,7 @@ class LocalRelation(LogicalPlan):
         self._schema = schema
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        plan = proto.Relation()
-
+        plan = self._create_proto_relation()
         if self._table is not None:
             sink = pa.BufferOutputStream()
             with pa.ipc.new_stream(sink, self._table.schema) as writer:
@@ -341,7 +358,7 @@ class ShowString(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.show_string.input.CopyFrom(self._child.plan(session))
         plan.show_string.num_rows = self.num_rows
         plan.show_string.truncate = self.truncate
@@ -378,6 +395,8 @@ class Project(LogicalPlan):
         from pyspark.sql.connect.functions import col
 
         assert self._child is not None
+        plan = self._create_proto_relation()
+        plan.project.input.CopyFrom(self._child.plan(session))
 
         proj_exprs = []
         for c in self._columns:
@@ -386,8 +405,6 @@ class Project(LogicalPlan):
             else:
                 proj_exprs.append(col(c).to_plan(session))
 
-        plan = proto.Relation()
-        plan.project.input.CopyFrom(self._child.plan(session))
         plan.project.expressions.extend(proj_exprs)
         return plan
 
@@ -426,7 +443,7 @@ class WithColumns(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.with_columns.input.CopyFrom(self._child.plan(session))
 
         for i in range(0, len(self._columnNames)):
@@ -461,7 +478,7 @@ class Hint(LogicalPlan):
         from pyspark.sql.connect.functions import array, lit
 
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.hint.input.CopyFrom(self._child.plan(session))
         plan.hint.name = self._name
         for param in self._parameters:
@@ -479,7 +496,7 @@ class Filter(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.filter.input.CopyFrom(self._child.plan(session))
         plan.filter.condition.CopyFrom(self.filter.to_plan(session))
         return plan
@@ -492,7 +509,7 @@ class Limit(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.limit.input.CopyFrom(self._child.plan(session))
         plan.limit.limit = self.limit
         return plan
@@ -505,7 +522,7 @@ class Tail(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.tail.input.CopyFrom(self._child.plan(session))
         plan.tail.limit = self.limit
         return plan
@@ -518,7 +535,7 @@ class Offset(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.offset.input.CopyFrom(self._child.plan(session))
         plan.offset.offset = self.offset
         return plan
@@ -537,7 +554,7 @@ class Deduplicate(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.deduplicate.input.CopyFrom(self._child.plan(session))
         plan.deduplicate.all_columns_as_keys = self.all_columns_as_keys
         if self.column_names is not None:
@@ -570,7 +587,7 @@ class Sort(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.sort.input.CopyFrom(self._child.plan(session))
         plan.sort.order.extend([self._convert_col(c, session) for c in self.columns])
         plan.sort.is_global = self.is_global
@@ -599,7 +616,7 @@ class Drop(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.drop.input.CopyFrom(self._child.plan(session))
         plan.drop.cols.extend([self._convert_to_expr(c, session) for c in self.columns])
         return plan
@@ -624,7 +641,7 @@ class Sample(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.sample.input.CopyFrom(self._child.plan(session))
         plan.sample.lower_bound = self.lower_bound
         plan.sample.upper_bound = self.upper_bound
@@ -672,32 +689,31 @@ class Aggregate(LogicalPlan):
         from pyspark.sql.connect.functions import lit
 
         assert self._child is not None
-
-        agg = proto.Relation()
-
-        agg.aggregate.input.CopyFrom(self._child.plan(session))
-
-        agg.aggregate.grouping_expressions.extend([c.to_plan(session) for c in self._grouping_cols])
-        agg.aggregate.aggregate_expressions.extend(
+        plan = self._create_proto_relation()
+        plan.aggregate.input.CopyFrom(self._child.plan(session))
+        plan.aggregate.grouping_expressions.extend(
+            [c.to_plan(session) for c in self._grouping_cols]
+        )
+        plan.aggregate.aggregate_expressions.extend(
             [c.to_plan(session) for c in self._aggregate_cols]
         )
 
         if self._group_type == "groupby":
-            agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY
+            plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_GROUPBY
         elif self._group_type == "rollup":
-            agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP
+            plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_ROLLUP
         elif self._group_type == "cube":
-            agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_CUBE
+            plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_CUBE
         elif self._group_type == "pivot":
-            agg.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_PIVOT
+            plan.aggregate.group_type = proto.Aggregate.GroupType.GROUP_TYPE_PIVOT
             assert self._pivot_col is not None
-            agg.aggregate.pivot.col.CopyFrom(self._pivot_col.to_plan(session))
+            plan.aggregate.pivot.col.CopyFrom(self._pivot_col.to_plan(session))
             if self._pivot_values is not None and len(self._pivot_values) > 0:
-                agg.aggregate.pivot.values.extend(
+                plan.aggregate.pivot.values.extend(
                     [lit(v).to_plan(session).literal for v in self._pivot_values]
                 )
 
-        return agg
+        return plan
 
 
 class Join(LogicalPlan):
@@ -742,23 +758,23 @@ class Join(LogicalPlan):
         self.how = join_type
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        rel = proto.Relation()
-        rel.join.left.CopyFrom(self.left.plan(session))
-        rel.join.right.CopyFrom(self.right.plan(session))
+        plan = self._create_proto_relation()
+        plan.join.left.CopyFrom(self.left.plan(session))
+        plan.join.right.CopyFrom(self.right.plan(session))
         if self.on is not None:
             if not isinstance(self.on, list):
                 if isinstance(self.on, str):
-                    rel.join.using_columns.append(self.on)
+                    plan.join.using_columns.append(self.on)
                 else:
-                    rel.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session))
+                    plan.join.join_condition.CopyFrom(self.to_attr_or_expression(self.on, session))
             elif len(self.on) > 0:
                 if isinstance(self.on[0], str):
-                    rel.join.using_columns.extend(cast(str, self.on))
+                    plan.join.using_columns.extend(cast(str, self.on))
                 else:
                     merge_column = functools.reduce(lambda c1, c2: c1 & c2, self.on)
-                    rel.join.join_condition.CopyFrom(cast(Column, merge_column).to_plan(session))
-        rel.join.join_type = self.how
-        return rel
+                    plan.join.join_condition.CopyFrom(cast(Column, merge_column).to_plan(session))
+        plan.join.join_type = self.how
+        return plan
 
     def print(self, indent: int = 0) -> str:
         i = " " * indent
@@ -800,29 +816,29 @@ class SetOperation(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        rel = proto.Relation()
+        plan = self._create_proto_relation()
         if self._child is not None:
-            rel.set_op.left_input.CopyFrom(self._child.plan(session))
+            plan.set_op.left_input.CopyFrom(self._child.plan(session))
         if self.other is not None:
-            rel.set_op.right_input.CopyFrom(self.other.plan(session))
+            plan.set_op.right_input.CopyFrom(self.other.plan(session))
         if self.set_op == "union":
-            rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_UNION
+            plan.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_UNION
         elif self.set_op == "intersect":
-            rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_INTERSECT
+            plan.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_INTERSECT
         elif self.set_op == "except":
-            rel.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_EXCEPT
+            plan.set_op.set_op_type = proto.SetOperation.SET_OP_TYPE_EXCEPT
         else:
             raise NotImplementedError(
                 """
                 Unsupported set operation type: %s.
                 """
-                % rel.set_op.set_op_type
+                % plan.set_op.set_op_type
             )
 
-        rel.set_op.is_all = self.is_all
-        rel.set_op.by_name = self.by_name
-        rel.set_op.allow_missing_columns = self.allow_missing_columns
-        return rel
+        plan.set_op.is_all = self.is_all
+        plan.set_op.by_name = self.by_name
+        plan.set_op.allow_missing_columns = self.allow_missing_columns
+        return plan
 
     def print(self, indent: int = 0) -> str:
         assert self._child is not None
@@ -860,12 +876,12 @@ class Repartition(LogicalPlan):
         self._shuffle = shuffle
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        rel = proto.Relation()
+        plan = self._create_proto_relation()
         if self._child is not None:
-            rel.repartition.input.CopyFrom(self._child.plan(session))
-        rel.repartition.shuffle = self._shuffle
-        rel.repartition.num_partitions = self._num_partitions
-        return rel
+            plan.repartition.input.CopyFrom(self._child.plan(session))
+        plan.repartition.shuffle = self._shuffle
+        plan.repartition.num_partitions = self._num_partitions
+        return plan
 
 
 class RepartitionByExpression(LogicalPlan):
@@ -882,7 +898,7 @@ class RepartitionByExpression(LogicalPlan):
         self.columns = columns
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        rel = proto.Relation()
+        plan = self._create_proto_relation()
 
         part_exprs = []
         for c in self.columns:
@@ -894,13 +910,13 @@ class RepartitionByExpression(LogicalPlan):
                 part_exprs.append(exp)
             else:
                 part_exprs.append(self.unresolved_attr(c))
-        rel.repartition_by_expression.partition_exprs.extend(part_exprs)
+        plan.repartition_by_expression.partition_exprs.extend(part_exprs)
 
         if self._child is not None:
-            rel.repartition_by_expression.input.CopyFrom(self._child.plan(session))
+            plan.repartition_by_expression.input.CopyFrom(self._child.plan(session))
         if self.num_partitions is not None:
-            rel.repartition_by_expression.num_partitions = self.num_partitions
-        return rel
+            plan.repartition_by_expression.num_partitions = self.num_partitions
+        return plan
 
 
 class SubqueryAlias(LogicalPlan):
@@ -911,11 +927,11 @@ class SubqueryAlias(LogicalPlan):
         self._alias = alias
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        rel = proto.Relation()
+        plan = self._create_proto_relation()
         if self._child is not None:
-            rel.subquery_alias.input.CopyFrom(self._child.plan(session))
-        rel.subquery_alias.alias = self._alias
-        return rel
+            plan.subquery_alias.input.CopyFrom(self._child.plan(session))
+        plan.subquery_alias.alias = self._alias
+        return plan
 
 
 class SQL(LogicalPlan):
@@ -931,14 +947,14 @@ class SQL(LogicalPlan):
         self._args = args
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        rel = proto.Relation()
-        rel.sql.query = self._query
+        plan = self._create_proto_relation()
+        plan.sql.query = self._query
 
         if self._args is not None and len(self._args) > 0:
             for k, v in self._args.items():
-                rel.sql.args[k] = v
+                plan.sql.args[k] = v
 
-        return rel
+        return plan
 
 
 class Range(LogicalPlan):
@@ -956,13 +972,13 @@ class Range(LogicalPlan):
         self._num_partitions = num_partitions
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
-        rel = proto.Relation()
-        rel.range.start = self._start
-        rel.range.end = self._end
-        rel.range.step = self._step
+        plan = self._create_proto_relation()
+        plan.range.start = self._start
+        plan.range.end = self._end
+        plan.range.step = self._step
         if self._num_partitions is not None:
-            rel.range.num_partitions = self._num_partitions
-        return rel
+            plan.range.num_partitions = self._num_partitions
+        return plan
 
 
 class ToSchema(LogicalPlan):
@@ -972,8 +988,7 @@ class ToSchema(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.to_schema.input.CopyFrom(self._child.plan(session))
         plan.to_schema.schema.CopyFrom(pyspark_types_to_proto_types(self._schema))
         return plan
@@ -986,8 +1001,7 @@ class WithColumnsRenamed(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.with_columns_renamed.input.CopyFrom(self._child.plan(session))
         for k, v in self._colsMap.items():
             plan.with_columns_renamed.rename_columns_map[k] = v
@@ -1019,8 +1033,7 @@ class Unpivot(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.unpivot.input.CopyFrom(self._child.plan(session))
         plan.unpivot.ids.extend([self.col_to_expr(x, session) for x in self.ids])
         if self.values is not None:
@@ -1064,7 +1077,7 @@ class NAFill(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.fill_na.input.CopyFrom(self._child.plan(session))
         if self.cols is not None and len(self.cols) > 0:
             plan.fill_na.cols.extend(self.cols)
@@ -1086,7 +1099,7 @@ class NADrop(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.drop_na.input.CopyFrom(self._child.plan(session))
         if self.cols is not None and len(self.cols) > 0:
             plan.drop_na.cols.extend(self.cols)
@@ -1122,7 +1135,7 @@ class NAReplace(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.replace.input.CopyFrom(self._child.plan(session))
         if self.cols is not None and len(self.cols) > 0:
             plan.replace.cols.extend(self.cols)
@@ -1150,7 +1163,7 @@ class StatSummary(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.summary.input.CopyFrom(self._child.plan(session))
         plan.summary.statistics.extend(self.statistics)
         return plan
@@ -1163,7 +1176,7 @@ class StatDescribe(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.describe.input.CopyFrom(self._child.plan(session))
         plan.describe.cols.extend(self.cols)
         return plan
@@ -1177,8 +1190,7 @@ class StatCov(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.cov.input.CopyFrom(self._child.plan(session))
         plan.cov.col1 = self._col1
         plan.cov.col2 = self._col2
@@ -1200,8 +1212,7 @@ class StatApproxQuantile(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.approx_quantile.input.CopyFrom(self._child.plan(session))
         plan.approx_quantile.cols.extend(self._cols)
         plan.approx_quantile.probabilities.extend(self._probabilities)
@@ -1217,8 +1228,7 @@ class StatCrosstab(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.crosstab.input.CopyFrom(self._child.plan(session))
         plan.crosstab.col1 = self.col1
         plan.crosstab.col2 = self.col2
@@ -1238,8 +1248,7 @@ class StatFreqItems(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.freq_items.input.CopyFrom(self._child.plan(session))
         plan.freq_items.cols.extend(self._cols)
         plan.freq_items.support = self._support
@@ -1275,8 +1284,7 @@ class StatSampleBy(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.sample_by.input.CopyFrom(self._child.plan(session))
         plan.sample_by.col.CopyFrom(self._col._expr.to_plan(session))
         if len(self._fractions) > 0:
@@ -1299,8 +1307,7 @@ class StatCorr(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.corr.input.CopyFrom(self._child.plan(session))
         plan.corr.col1 = self._col1
         plan.corr.col2 = self._col2
@@ -1315,8 +1322,7 @@ class ToDF(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
-
-        plan = proto.Relation()
+        plan = self._create_proto_relation()
         plan.to_df.input.CopyFrom(self._child.plan(session))
         plan.to_df.column_names.extend(self._cols)
         return plan
@@ -1333,8 +1339,8 @@ class CreateView(LogicalPlan):
 
     def command(self, session: "SparkConnectClient") -> proto.Command:
         assert self._child is not None
-
         plan = proto.Command()
+
         plan.create_dataframe_view.replace = self._replace
         plan.create_dataframe_view.is_global = self._is_global
         plan.create_dataframe_view.name = self._name
@@ -1358,6 +1364,7 @@ class WriteOperation(LogicalPlan):
     def command(self, session: "SparkConnectClient") -> proto.Command:
         assert self._child is not None
         plan = proto.Command()
+
         plan.write_operation.input.CopyFrom(self._child.plan(session))
         if self.source is not None:
             plan.write_operation.source = self.source
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 92d9e6a610a..891be5ea9ea 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\x92%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xbc%\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunct [...]
 )
 
 
@@ -300,7 +300,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 4859
+    _EXPRESSION._serialized_end = 4901
     _EXPRESSION_WINDOW._serialized_start = 1475
     _EXPRESSION_WINDOW._serialized_end = 2258
     _EXPRESSION_WINDOW_WINDOWFRAME._serialized_start = 1765
@@ -324,29 +324,29 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_start = 3599
     _EXPRESSION_LITERAL_CALENDARINTERVAL._serialized_end = 3697
     _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_start = 3715
-    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3785
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3788
-    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 3992
-    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 3994
-    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4044
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4046
-    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4128
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4130
-    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4174
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4177
-    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4309
-    _EXPRESSION_UPDATEFIELDS._serialized_start = 4312
-    _EXPRESSION_UPDATEFIELDS._serialized_end = 4499
-    _EXPRESSION_ALIAS._serialized_start = 4501
-    _EXPRESSION_ALIAS._serialized_end = 4621
-    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4624
-    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4782
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4784
-    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4846
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4862
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5173
-    _PYTHONUDF._serialized_start = 5176
-    _PYTHONUDF._serialized_end = 5306
-    _SCALARSCALAUDF._serialized_start = 5309
-    _SCALARSCALAUDF._serialized_end = 5493
+    _EXPRESSION_UNRESOLVEDATTRIBUTE._serialized_end = 3827
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_start = 3830
+    _EXPRESSION_UNRESOLVEDFUNCTION._serialized_end = 4034
+    _EXPRESSION_EXPRESSIONSTRING._serialized_start = 4036
+    _EXPRESSION_EXPRESSIONSTRING._serialized_end = 4086
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_start = 4088
+    _EXPRESSION_UNRESOLVEDSTAR._serialized_end = 4170
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_start = 4172
+    _EXPRESSION_UNRESOLVEDREGEX._serialized_end = 4216
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_start = 4219
+    _EXPRESSION_UNRESOLVEDEXTRACTVALUE._serialized_end = 4351
+    _EXPRESSION_UPDATEFIELDS._serialized_start = 4354
+    _EXPRESSION_UPDATEFIELDS._serialized_end = 4541
+    _EXPRESSION_ALIAS._serialized_start = 4543
+    _EXPRESSION_ALIAS._serialized_end = 4663
+    _EXPRESSION_LAMBDAFUNCTION._serialized_start = 4666
+    _EXPRESSION_LAMBDAFUNCTION._serialized_end = 4824
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 4826
+    _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 4888
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 4904
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5215
+    _PYTHONUDF._serialized_start = 5218
+    _PYTHONUDF._serialized_end = 5348
+    _SCALARSCALAUDF._serialized_start = 5351
+    _SCALARSCALAUDF._serialized_end = 5535
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 934e0016c90..88b1fd8ef7e 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -613,19 +613,37 @@ class Expression(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
         UNPARSED_IDENTIFIER_FIELD_NUMBER: builtins.int
+        PLAN_ID_FIELD_NUMBER: builtins.int
         unparsed_identifier: builtins.str
         """(Required) An identifier that will be parsed by Catalyst parser. This should follow the
         Spark SQL identifier syntax.
         """
+        plan_id: builtins.int
+        """(Optional) The id of corresponding connect plan."""
         def __init__(
             self,
             *,
             unparsed_identifier: builtins.str = ...,
+            plan_id: builtins.int | None = ...,
         ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal["_plan_id", b"_plan_id", "plan_id", b"plan_id"],
+        ) -> builtins.bool: ...
         def ClearField(
             self,
-            field_name: typing_extensions.Literal["unparsed_identifier", b"unparsed_identifier"],
+            field_name: typing_extensions.Literal[
+                "_plan_id",
+                b"_plan_id",
+                "plan_id",
+                b"plan_id",
+                "unparsed_identifier",
+                b"unparsed_identifier",
+            ],
         ) -> None: ...
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"]
+        ) -> typing_extensions.Literal["plan_id"] | None: ...
 
     class UnresolvedFunction(google.protobuf.message.Message):
         """An unresolved function is not explicitly bound to one explicit function, but the function
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index ece5920953e..057b96a8da9 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -639,103 +639,103 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _UNKNOWN._serialized_start = 2464
     _UNKNOWN._serialized_end = 2473
     _RELATIONCOMMON._serialized_start = 2475
-    _RELATIONCOMMON._serialized_end = 2524
-    _SQL._serialized_start = 2527
-    _SQL._serialized_end = 2661
-    _SQL_ARGSENTRY._serialized_start = 2606
-    _SQL_ARGSENTRY._serialized_end = 2661
-    _READ._serialized_start = 2664
-    _READ._serialized_end = 3112
-    _READ_NAMEDTABLE._serialized_start = 2806
-    _READ_NAMEDTABLE._serialized_end = 2867
-    _READ_DATASOURCE._serialized_start = 2870
-    _READ_DATASOURCE._serialized_end = 3099
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3030
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3088
-    _PROJECT._serialized_start = 3114
-    _PROJECT._serialized_end = 3231
-    _FILTER._serialized_start = 3233
-    _FILTER._serialized_end = 3345
-    _JOIN._serialized_start = 3348
-    _JOIN._serialized_end = 3819
-    _JOIN_JOINTYPE._serialized_start = 3611
-    _JOIN_JOINTYPE._serialized_end = 3819
-    _SETOPERATION._serialized_start = 3822
-    _SETOPERATION._serialized_end = 4301
-    _SETOPERATION_SETOPTYPE._serialized_start = 4138
-    _SETOPERATION_SETOPTYPE._serialized_end = 4252
-    _LIMIT._serialized_start = 4303
-    _LIMIT._serialized_end = 4379
-    _OFFSET._serialized_start = 4381
-    _OFFSET._serialized_end = 4460
-    _TAIL._serialized_start = 4462
-    _TAIL._serialized_end = 4537
-    _AGGREGATE._serialized_start = 4540
-    _AGGREGATE._serialized_end = 5122
-    _AGGREGATE_PIVOT._serialized_start = 4879
-    _AGGREGATE_PIVOT._serialized_end = 4990
-    _AGGREGATE_GROUPTYPE._serialized_start = 4993
-    _AGGREGATE_GROUPTYPE._serialized_end = 5122
-    _SORT._serialized_start = 5125
-    _SORT._serialized_end = 5285
-    _DROP._serialized_start = 5287
-    _DROP._serialized_end = 5387
-    _DEDUPLICATE._serialized_start = 5390
-    _DEDUPLICATE._serialized_end = 5561
-    _LOCALRELATION._serialized_start = 5563
-    _LOCALRELATION._serialized_end = 5652
-    _SAMPLE._serialized_start = 5655
-    _SAMPLE._serialized_end = 5928
-    _RANGE._serialized_start = 5931
-    _RANGE._serialized_end = 6076
-    _SUBQUERYALIAS._serialized_start = 6078
-    _SUBQUERYALIAS._serialized_end = 6192
-    _REPARTITION._serialized_start = 6195
-    _REPARTITION._serialized_end = 6337
-    _SHOWSTRING._serialized_start = 6340
-    _SHOWSTRING._serialized_end = 6482
-    _STATSUMMARY._serialized_start = 6484
-    _STATSUMMARY._serialized_end = 6576
-    _STATDESCRIBE._serialized_start = 6578
-    _STATDESCRIBE._serialized_end = 6659
-    _STATCROSSTAB._serialized_start = 6661
-    _STATCROSSTAB._serialized_end = 6762
-    _STATCOV._serialized_start = 6764
-    _STATCOV._serialized_end = 6860
-    _STATCORR._serialized_start = 6863
-    _STATCORR._serialized_end = 7000
-    _STATAPPROXQUANTILE._serialized_start = 7003
-    _STATAPPROXQUANTILE._serialized_end = 7167
-    _STATFREQITEMS._serialized_start = 7169
-    _STATFREQITEMS._serialized_end = 7294
-    _STATSAMPLEBY._serialized_start = 7297
-    _STATSAMPLEBY._serialized_end = 7606
-    _STATSAMPLEBY_FRACTION._serialized_start = 7498
-    _STATSAMPLEBY_FRACTION._serialized_end = 7597
-    _NAFILL._serialized_start = 7609
-    _NAFILL._serialized_end = 7743
-    _NADROP._serialized_start = 7746
-    _NADROP._serialized_end = 7880
-    _NAREPLACE._serialized_start = 7883
-    _NAREPLACE._serialized_end = 8179
-    _NAREPLACE_REPLACEMENT._serialized_start = 8038
-    _NAREPLACE_REPLACEMENT._serialized_end = 8179
-    _TODF._serialized_start = 8181
-    _TODF._serialized_end = 8269
-    _WITHCOLUMNSRENAMED._serialized_start = 8272
-    _WITHCOLUMNSRENAMED._serialized_end = 8511
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8444
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8511
-    _WITHCOLUMNS._serialized_start = 8513
-    _WITHCOLUMNS._serialized_end = 8632
-    _HINT._serialized_start = 8635
-    _HINT._serialized_end = 8767
-    _UNPIVOT._serialized_start = 8770
-    _UNPIVOT._serialized_end = 9097
-    _UNPIVOT_VALUES._serialized_start = 9027
-    _UNPIVOT_VALUES._serialized_end = 9086
-    _TOSCHEMA._serialized_start = 9099
-    _TOSCHEMA._serialized_end = 9205
-    _REPARTITIONBYEXPRESSION._serialized_start = 9208
-    _REPARTITIONBYEXPRESSION._serialized_end = 9411
+    _RELATIONCOMMON._serialized_end = 2566
+    _SQL._serialized_start = 2569
+    _SQL._serialized_end = 2703
+    _SQL_ARGSENTRY._serialized_start = 2648
+    _SQL_ARGSENTRY._serialized_end = 2703
+    _READ._serialized_start = 2706
+    _READ._serialized_end = 3154
+    _READ_NAMEDTABLE._serialized_start = 2848
+    _READ_NAMEDTABLE._serialized_end = 2909
+    _READ_DATASOURCE._serialized_start = 2912
+    _READ_DATASOURCE._serialized_end = 3141
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3072
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3130
+    _PROJECT._serialized_start = 3156
+    _PROJECT._serialized_end = 3273
+    _FILTER._serialized_start = 3275
+    _FILTER._serialized_end = 3387
+    _JOIN._serialized_start = 3390
+    _JOIN._serialized_end = 3861
+    _JOIN_JOINTYPE._serialized_start = 3653
+    _JOIN_JOINTYPE._serialized_end = 3861
+    _SETOPERATION._serialized_start = 3864
+    _SETOPERATION._serialized_end = 4343
+    _SETOPERATION_SETOPTYPE._serialized_start = 4180
+    _SETOPERATION_SETOPTYPE._serialized_end = 4294
+    _LIMIT._serialized_start = 4345
+    _LIMIT._serialized_end = 4421
+    _OFFSET._serialized_start = 4423
+    _OFFSET._serialized_end = 4502
+    _TAIL._serialized_start = 4504
+    _TAIL._serialized_end = 4579
+    _AGGREGATE._serialized_start = 4582
+    _AGGREGATE._serialized_end = 5164
+    _AGGREGATE_PIVOT._serialized_start = 4921
+    _AGGREGATE_PIVOT._serialized_end = 5032
+    _AGGREGATE_GROUPTYPE._serialized_start = 5035
+    _AGGREGATE_GROUPTYPE._serialized_end = 5164
+    _SORT._serialized_start = 5167
+    _SORT._serialized_end = 5327
+    _DROP._serialized_start = 5329
+    _DROP._serialized_end = 5429
+    _DEDUPLICATE._serialized_start = 5432
+    _DEDUPLICATE._serialized_end = 5603
+    _LOCALRELATION._serialized_start = 5605
+    _LOCALRELATION._serialized_end = 5694
+    _SAMPLE._serialized_start = 5697
+    _SAMPLE._serialized_end = 5970
+    _RANGE._serialized_start = 5973
+    _RANGE._serialized_end = 6118
+    _SUBQUERYALIAS._serialized_start = 6120
+    _SUBQUERYALIAS._serialized_end = 6234
+    _REPARTITION._serialized_start = 6237
+    _REPARTITION._serialized_end = 6379
+    _SHOWSTRING._serialized_start = 6382
+    _SHOWSTRING._serialized_end = 6524
+    _STATSUMMARY._serialized_start = 6526
+    _STATSUMMARY._serialized_end = 6618
+    _STATDESCRIBE._serialized_start = 6620
+    _STATDESCRIBE._serialized_end = 6701
+    _STATCROSSTAB._serialized_start = 6703
+    _STATCROSSTAB._serialized_end = 6804
+    _STATCOV._serialized_start = 6806
+    _STATCOV._serialized_end = 6902
+    _STATCORR._serialized_start = 6905
+    _STATCORR._serialized_end = 7042
+    _STATAPPROXQUANTILE._serialized_start = 7045
+    _STATAPPROXQUANTILE._serialized_end = 7209
+    _STATFREQITEMS._serialized_start = 7211
+    _STATFREQITEMS._serialized_end = 7336
+    _STATSAMPLEBY._serialized_start = 7339
+    _STATSAMPLEBY._serialized_end = 7648
+    _STATSAMPLEBY_FRACTION._serialized_start = 7540
+    _STATSAMPLEBY_FRACTION._serialized_end = 7639
+    _NAFILL._serialized_start = 7651
+    _NAFILL._serialized_end = 7785
+    _NADROP._serialized_start = 7788
+    _NADROP._serialized_end = 7922
+    _NAREPLACE._serialized_start = 7925
+    _NAREPLACE._serialized_end = 8221
+    _NAREPLACE_REPLACEMENT._serialized_start = 8080
+    _NAREPLACE_REPLACEMENT._serialized_end = 8221
+    _TODF._serialized_start = 8223
+    _TODF._serialized_end = 8311
+    _WITHCOLUMNSRENAMED._serialized_start = 8314
+    _WITHCOLUMNSRENAMED._serialized_end = 8553
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8486
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8553
+    _WITHCOLUMNS._serialized_start = 8555
+    _WITHCOLUMNS._serialized_end = 8674
+    _HINT._serialized_start = 8677
+    _HINT._serialized_end = 8809
+    _UNPIVOT._serialized_start = 8812
+    _UNPIVOT._serialized_end = 9139
+    _UNPIVOT_VALUES._serialized_start = 9069
+    _UNPIVOT_VALUES._serialized_end = 9128
+    _TOSCHEMA._serialized_start = 9141
+    _TOSCHEMA._serialized_end = 9247
+    _REPARTITIONBYEXPRESSION._serialized_start = 9250
+    _REPARTITIONBYEXPRESSION._serialized_end = 9453
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 41962ee4062..b7cef7b299d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -478,16 +478,29 @@ class RelationCommon(google.protobuf.message.Message):
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
     SOURCE_INFO_FIELD_NUMBER: builtins.int
+    PLAN_ID_FIELD_NUMBER: builtins.int
     source_info: builtins.str
     """(Required) Shared relation metadata."""
+    plan_id: builtins.int
+    """(Optional) A per-client globally unique id for a given connect plan."""
     def __init__(
         self,
         *,
         source_info: builtins.str = ...,
+        plan_id: builtins.int | None = ...,
     ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["_plan_id", b"_plan_id", "plan_id", b"plan_id"]
+    ) -> builtins.bool: ...
     def ClearField(
-        self, field_name: typing_extensions.Literal["source_info", b"source_info"]
+        self,
+        field_name: typing_extensions.Literal[
+            "_plan_id", b"_plan_id", "plan_id", b"plan_id", "source_info", b"source_info"
+        ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_plan_id", b"_plan_id"]
+    ) -> typing_extensions.Literal["plan_id"] | None: ...
 
 global___RelationCommon = RelationCommon
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index a723163cbe8..9e9341c9a2a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -356,6 +356,68 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         )
         self.assert_eq(joined_plan3.toPandas(), joined_plan4.toPandas())
 
+    def test_join_ambiguous_cols(self):
+        # SPARK-41812: test join with ambiguous columns
+        data1 = [Row(id=1, value="foo"), Row(id=2, value=None)]
+        cdf1 = self.connect.createDataFrame(data1)
+        sdf1 = self.spark.createDataFrame(data1)
+
+        data2 = [Row(value="bar"), Row(value=None), Row(value="foo")]
+        cdf2 = self.connect.createDataFrame(data2)
+        sdf2 = self.spark.createDataFrame(data2)
+
+        cdf3 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"])
+        sdf3 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"])
+
+        self.assertEqual(cdf3.schema, sdf3.schema)
+        self.assertEqual(cdf3.collect(), sdf3.collect())
+
+        cdf4 = cdf1.join(cdf2, cdf1["value"].eqNullSafe(cdf2["value"]))
+        sdf4 = sdf1.join(sdf2, sdf1["value"].eqNullSafe(sdf2["value"]))
+
+        self.assertEqual(cdf4.schema, sdf4.schema)
+        self.assertEqual(cdf4.collect(), sdf4.collect())
+
+        cdf5 = cdf1.join(
+            cdf2, (cdf1["value"] == cdf2["value"]) & (cdf1["value"].eqNullSafe(cdf2["value"]))
+        )
+        sdf5 = sdf1.join(
+            sdf2, (sdf1["value"] == sdf2["value"]) & (sdf1["value"].eqNullSafe(sdf2["value"]))
+        )
+
+        self.assertEqual(cdf5.schema, sdf5.schema)
+        self.assertEqual(cdf5.collect(), sdf5.collect())
+
+        cdf6 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf1.value)
+        sdf6 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf1.value)
+
+        self.assertEqual(cdf6.schema, sdf6.schema)
+        self.assertEqual(cdf6.collect(), sdf6.collect())
+
+        cdf7 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"]).select(cdf2.value)
+        sdf7 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"]).select(sdf2.value)
+
+        self.assertEqual(cdf7.schema, sdf7.schema)
+        self.assertEqual(cdf7.collect(), sdf7.collect())
+
+    def test_invalid_column(self):
+        # SPARK-41812: fail df1.select(df2.col)
+        data1 = [Row(a=1, b=2, c=3)]
+        cdf1 = self.connect.createDataFrame(data1)
+
+        data2 = [Row(a=2, b=0)]
+        cdf2 = self.connect.createDataFrame(data2)
+
+        with self.assertRaises(AnalysisException):
+            cdf1.select(cdf2.a).schema
+
+        with self.assertRaises(AnalysisException):
+            cdf2.withColumn("x", cdf1.a + 1).schema
+
+        with self.assertRaisesRegex(AnalysisException, "attribute.*missing"):
+            cdf3 = cdf1.select(cdf1.a)
+            cdf3.select(cdf1.b).schema
+
     def test_collect(self):
         cdf = self.connect.read.table(self.tbl_name)
         sdf = self.spark.read.table(self.tbl_name)
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 1892e64f8f9..a5f691d0bef 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -85,7 +85,18 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
         right_input = self.connect.readTable(table_name=self.tbl_name)
         crossJoin_plan = left_input.crossJoin(other=right_input)._plan.to_proto(self.connect)
         join_plan = left_input.join(other=right_input, how="cross")._plan.to_proto(self.connect)
-        self.assertEqual(crossJoin_plan, join_plan)
+        self.assertEqual(
+            crossJoin_plan.root.join.left.read.named_table,
+            join_plan.root.join.left.read.named_table,
+        )
+        self.assertEqual(
+            crossJoin_plan.root.join.right.read.named_table,
+            join_plan.root.join.right.read.named_table,
+        )
+        self.assertEqual(
+            crossJoin_plan.root.join.join_type,
+            join_plan.root.join.join_type,
+        )
 
     def test_filter(self):
         df = self.connect.readTable(table_name=self.tbl_name)
@@ -732,7 +743,12 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
 
         self.assertIsNotNone(cp1)
         self.assertEqual(cp1, cp2)
-        self.assertEqual(cp2, cp3)
+        self.assertEqual(
+            cp2.unresolved_attribute.unparsed_identifier,
+            cp3.unresolved_attribute.unparsed_identifier,
+        )
+        self.assertTrue(cp2.unresolved_attribute.HasField("plan_id"))
+        self.assertFalse(cp3.unresolved_attribute.HasField("plan_id"))
 
     def test_null_literal(self):
         null_lit = lit(None)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9a2648a79a5..eff8c114a97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3307,12 +3307,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       _.containsPattern(NATURAL_LIKE_JOIN), ruleId) {
       case j @ Join(left, right, UsingJoin(joinType, usingCols), _, hint)
           if left.resolved && right.resolved && j.duplicateResolved =>
-        commonNaturalJoinProcessing(left, right, joinType, usingCols, None, hint)
+        commonNaturalJoinProcessing(left, right, joinType, usingCols, None, hint,
+          j.getTagValue(LogicalPlan.PLAN_ID_TAG))
       case j @ Join(left, right, NaturalJoin(joinType), condition, hint)
           if j.resolvedExceptNatural =>
         // find common column names from both sides
         val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
-        commonNaturalJoinProcessing(left, right, joinType, joinNames, condition, hint)
+        commonNaturalJoinProcessing(left, right, joinType, joinNames, condition, hint,
+          j.getTagValue(LogicalPlan.PLAN_ID_TAG))
     }
   }
 
@@ -3442,7 +3444,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       joinType: JoinType,
       joinNames: Seq[String],
       condition: Option[Expression],
-      hint: JoinHint): LogicalPlan = {
+      hint: JoinHint,
+      planId: Option[Long] = None): LogicalPlan = {
     import org.apache.spark.sql.catalyst.util._
 
     val leftKeys = joinNames.map { keyName =>
@@ -3483,9 +3486,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case _ =>
         throw QueryExecutionErrors.unsupportedNaturalJoinTypeError(joinType)
     }
+
+    val newJoin = Join(left, right, joinType, newCondition, hint)
+    // retain the plan id used in Spark Connect
+    planId.foreach(newJoin.setTagValue(LogicalPlan.PLAN_ID_TAG, _))
+
     // use Project to hide duplicated common keys
     // propagate hidden columns from nested USING/NATURAL JOINs
-    val project = Project(projectList, Join(left, right, joinType, newCondition, hint))
+    val project = Project(projectList, newJoin)
     project.setTagValue(
       Project.hiddenOutputTag,
       hiddenList.map(_.markAsQualifiedAccessOnly()) ++
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 0985fe6852c..ba550bce791 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -357,8 +357,21 @@ trait ColumnResolutionHelper extends Logging {
       e: Expression,
       q: LogicalPlan,
       allowOuter: Boolean = false): Expression = {
+    val newE = if (e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) {
+      // If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, the plan and expression come
+      // from Spark Connect and need to be resolved as follows:
+      //    1. extract the attached plan id from the expression (UnresolvedAttribute only for now);
+      //    2. traverse the query plan top-down to find the plan node that matches the plan id;
+      //    3. if no matching node is found, fail the analysis due to an illegal reference;
+      //    4. resolve the expression with the matching node; if any error occurs here, fall
+      //    back to the old code path.
+      resolveExpressionByPlanId(e, q)
+    } else {
+      e
+    }
+
     resolveExpression(
-      e,
+      newE,
       resolveColumnByName = nameParts => {
         q.resolveChildren(nameParts, conf.resolver)
       },
@@ -369,4 +382,47 @@ trait ColumnResolutionHelper extends Logging {
       throws = true,
       allowOuter = allowOuter)
   }
+
+  /**
+   * Rewrites `e` by resolving every [[UnresolvedAttribute]] that carries a
+   * `LogicalPlan.PLAN_ID_TAG` through `resolveUnresolvedAttributeByPlanId`. Subtrees that
+   * contain no tagged node are returned untouched, and a tagged attribute that can not be
+   * resolved this way is kept as is.
+   */
+  private def resolveExpressionByPlanId(
+      e: Expression,
+      q: LogicalPlan): Expression = {
+    val containsPlanId = e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)
+    if (containsPlanId) {
+      e match {
+        case attr: UnresolvedAttribute =>
+          resolveUnresolvedAttributeByPlanId(attr, q).getOrElse(attr)
+        case other =>
+          // Only descend into subtrees; each child repeats the tag check above.
+          other.mapChildren(child => resolveExpressionByPlanId(child, q))
+      }
+    } else {
+      e
+    }
+  }
+
+  /**
+   * Resolves `u` with the node of `q` that is tagged with the same plan id as `u`.
+   *
+   * @param u the attribute to resolve, possibly carrying `LogicalPlan.PLAN_ID_TAG`
+   * @param q the query plan that is searched for the node with the matching plan id
+   * @return the result of resolving `u` with the matching node; None when `u` carries no plan
+   *         id or when that resolution raises an [[AnalysisException]], in which case the
+   *         caller keeps `u` unresolved
+   * @throws AnalysisException if `u` carries a plan id but no node of `q` is tagged with it
+   */
+  private def resolveUnresolvedAttributeByPlanId(
+      u: UnresolvedAttribute,
+      q: LogicalPlan): Option[NamedExpression] = {
+    u.getTagValue(LogicalPlan.PLAN_ID_TAG).flatMap { planId =>
+      logDebug(s"Extract plan_id $planId from $u")
+
+      val plan = q.find(_.getTagValue(LogicalPlan.PLAN_ID_TAG).contains(planId)).getOrElse {
+        // For example:
+        //  df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)])
+        //  df2 = spark.createDataFrame([Row(a = 1, b = 2)])
+        //  df1.select(df2.a)   <-   illegal reference df2.a
+        throw new AnalysisException(s"When resolving $u, " +
+          s"fail to find subplan with plan_id=$planId in $q")
+      }
+
+      try {
+        plan.resolve(u.nameParts, conf.resolver)
+      } catch {
+        case e: AnalysisException =>
+          // Best effort: swallow the failure and report "not resolved" instead of failing here.
+          logDebug(s"Fail to resolve $u with $plan due to $e")
+          None
+      }
+    }
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 5a7dcff3667..36187bb2d55 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{AliasAwareQueryOutputOrdering, QueryPlan}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats
-import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, UnaryLike}
+import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, UnaryLike}
 import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -157,6 +157,18 @@ abstract class LogicalPlan
   }
 }
 
+object LogicalPlan {
+  // A dedicated tag for Spark Connect.
+  // If an expression (only UnresolvedAttribute is supported for now) carries this tag,
+  // the analyzer will:
+  //    1. extract the plan id;
+  //    2. traverse the query plan top-down to find the node that carries the same tag value,
+  //    and fail the whole analysis if it can not be found;
+  //    3. resolve this expression with the matching node. If any error occurs, the analyzer
+  //    falls back to the old code path.
+  private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id")
+}
+
 /**
  * A logical plan node with no children.
  */


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org