Posted to commits@spark.apache.org by gu...@apache.org on 2022/11/16 09:20:29 UTC

[spark] branch master updated: [SPARK-40809][CONNECT][FOLLOW] Support `alias()` in Python client

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0f7eaeee644 [SPARK-40809][CONNECT][FOLLOW] Support `alias()` in Python client
0f7eaeee644 is described below

commit 0f7eaeee6445aaf05310229bdec22f6953113f2e
Author: Martin Grund <ma...@databricks.com>
AuthorDate: Wed Nov 16 18:20:05 2022 +0900

    [SPARK-40809][CONNECT][FOLLOW] Support `alias()` in Python client
    
    ### What changes were proposed in this pull request?
    This extends the column alias implementation in Spark Connect to support lists of column names and provides the corresponding implementation on the Python side (a short usage sketch follows the list of changed files below).
    
    ### Why are the changes needed?
    Compatibility
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    UT in Python and Scala
    
    Closes #38631 from grundprinzip/SPARK-40809-f.
    
    Lead-authored-by: Martin Grund <ma...@databricks.com>
    Co-authored-by: Martin Grund <gr...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto  |  4 +-
 .../org/apache/spark/sql/connect/dsl/package.scala | 30 +++++++++-
 .../connect/planner/DataTypeProtoConverter.scala   | 20 ++++++-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 21 ++++++-
 .../connect/planner/SparkConnectProtoSuite.scala   | 37 +++++++++++-
 python/pyspark/sql/connect/column.py               | 66 ++++++++++++++++++++++
 python/pyspark/sql/connect/dataframe.py            |  2 +-
 .../pyspark/sql/connect/proto/expressions_pb2.py   |  6 +-
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  | 24 ++++++--
 .../sql/tests/connect/test_connect_basic.py        | 16 +++++-
 .../connect/test_connect_column_expressions.py     | 10 ++++
 python/pyspark/testing/connectutils.py             |  4 ++
 12 files changed, 223 insertions(+), 17 deletions(-)

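As a concrete illustration of the change described above, here is a minimal usage sketch of the new Python client API, based on the tests added in this commit (`spark` is assumed to be a RemoteSparkSession):

    from pyspark.sql.connect.functions import col

    # Alias a single column; the metadata dict ends up on the resulting StructField.
    df = spark.range(1, 10).select(col("id").alias("name", metadata={"max": 99}))
    df.schema().names[0]   # 'name'
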
diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
index 426e04341d9..ac5fe24d349 100644
--- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
@@ -170,6 +170,8 @@ message Expression {
 
   message Alias {
     Expression expr = 1;
-    string name = 2;
+    repeated string name = 2;
+    // Alias metadata expressed as a JSON map.
+    optional string metadata = 3;
   }
 }
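For orientation, a sketch (not part of this commit) of how the updated Alias message can be populated through the generated Python bindings, which now expose `name` as a repeated string and `metadata` as an optional JSON string:

    from pyspark.sql.connect.proto import expressions_pb2 as proto

    alias = proto.Expression.Alias()
    alias.name.append("id2")          # single name
    alias.metadata = '{"max": 99}'    # metadata is a JSON-encoded map
    alias.HasField("metadata")        # True once the optional field is set

    multi = proto.Expression.Alias()
    multi.name.extend(["a", "b"])     # several names, e.g. for explode; no metadata here
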
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 98879b69b84..caff3d8f071 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -79,7 +79,28 @@ package object dsl {
     implicit class DslExpression(val expr: Expression) {
       def as(alias: String): Expression = Expression
         .newBuilder()
-        .setAlias(Expression.Alias.newBuilder().setName(alias).setExpr(expr))
+        .setAlias(Expression.Alias.newBuilder().addName(alias).setExpr(expr))
+        .build()
+
+      def as(alias: String, metadata: String): Expression = Expression
+        .newBuilder()
+        .setAlias(
+          Expression.Alias
+            .newBuilder()
+            .setExpr(expr)
+            .addName(alias)
+            .setMetadata(metadata)
+            .build())
+        .build()
+
+      def as(alias: Seq[String]): Expression = Expression
+        .newBuilder()
+        .setAlias(
+          Expression.Alias
+            .newBuilder()
+            .setExpr(expr)
+            .addAllName(alias.asJava)
+            .build())
         .build()
 
       def <(other: Expression): Expression =
@@ -101,6 +122,13 @@ package object dsl {
           Expression.UnresolvedFunction.newBuilder().addParts("min").addArguments(e))
         .build()
 
+    def proto_explode(e: Expression): Expression =
+      Expression
+        .newBuilder()
+        .setUnresolvedFunction(
+          Expression.UnresolvedFunction.newBuilder().addParts("explode").addArguments(e))
+        .build()
+
     /**
      * Create an unresolved function from name parts.
      *
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
index 0ee90b5e8fb..088030b2dbc 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
@@ -21,7 +21,7 @@ import scala.collection.convert.ImplicitConversions._
 
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.SaveMode
-import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, IntegerType, LongType, MapType, StringType, StructField, StructType}
 
 /**
  * This object offers methods to convert to/from connect proto to catalyst types.
@@ -32,6 +32,7 @@ object DataTypeProtoConverter {
       case proto.DataType.KindCase.I32 => IntegerType
       case proto.DataType.KindCase.STRING => StringType
       case proto.DataType.KindCase.STRUCT => convertProtoDataTypeToCatalyst(t.getStruct)
+      case proto.DataType.KindCase.MAP => convertProtoDataTypeToCatalyst(t.getMap)
       case _ =>
         throw InvalidPlanInput(s"Does not support convert ${t.getKindCase} to catalyst types.")
     }
@@ -44,6 +45,10 @@ object DataTypeProtoConverter {
     StructType.apply(structFields)
   }
 
+  private def convertProtoDataTypeToCatalyst(t: proto.DataType.Map): MapType = {
+    MapType(toCatalystType(t.getKey), toCatalystType(t.getValue))
+  }
+
   def toConnectProtoType(t: DataType): proto.DataType = {
     t match {
       case IntegerType =>
@@ -54,11 +59,24 @@ object DataTypeProtoConverter {
         proto.DataType.newBuilder().setI64(proto.DataType.I64.getDefaultInstance).build()
       case struct: StructType =>
         toConnectProtoStructType(struct)
+      case map: MapType => toConnectProtoMapType(map)
       case _ =>
         throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
     }
   }
 
+  def toConnectProtoMapType(schema: MapType): proto.DataType = {
+    proto.DataType
+      .newBuilder()
+      .setMap(
+        proto.DataType.Map
+          .newBuilder()
+          .setKey(toConnectProtoType(schema.keyType))
+          .setValue(toConnectProtoType(schema.valueType))
+          .build())
+      .build()
+  }
+
   def toConnectProtoStructType(schema: StructType): proto.DataType = {
     val struct = proto.DataType.Struct.newBuilder()
     for (structField <- schema.fields) {
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index f5801d57c29..232f6e10474 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -27,7 +27,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.WriteOperation
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.AliasIdentifier
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
@@ -338,7 +338,9 @@ class SparkConnectPlanner(session: SparkSession) {
       case proto.Expression.ExprTypeCase.ALIAS => transformAlias(exp.getAlias)
       case proto.Expression.ExprTypeCase.EXPRESSION_STRING =>
         transformExpressionString(exp.getExpressionString)
-      case _ => throw InvalidPlanInput()
+      case _ =>
+        throw InvalidPlanInput(
+          s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported")
     }
   }
 
@@ -412,7 +414,20 @@ class SparkConnectPlanner(session: SparkSession) {
   }
 
   private def transformAlias(alias: proto.Expression.Alias): NamedExpression = {
-    Alias(transformExpression(alias.getExpr), alias.getName)()
+    if (alias.getNameCount == 1) {
+      val md = if (alias.hasMetadata()) {
+        Some(Metadata.fromJson(alias.getMetadata))
+      } else {
+        None
+      }
+      Alias(transformExpression(alias.getExpr), alias.getName(0))(explicitMetadata = md)
+    } else {
+      if (alias.hasMetadata) {
+        throw new InvalidPlanInput(
+          "Alias expressions with more than 1 identifier must not use optional metadata.")
+      }
+      MultiAlias(transformExpression(alias.getExpr), alias.getNameList.asScala.toSeq)
+    }
   }
 
   private def transformExpressionString(expr: proto.Expression.ExpressionString): Expression = {
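The planner now maps a single-name alias to a Catalyst Alias (optionally carrying Metadata parsed from the JSON string) and a multi-name alias to MultiAlias, rejecting metadata in the multi-name case. A sketch of how these rules surface through the Python client, following the test added further below (`spark` is again assumed to be a RemoteSparkSession):

    import grpc
    from pyspark.sql.connect.functions import col

    try:
        # "id" is a single column, so a three-name alias cannot be resolved.
        spark.range(1, 10).select(col("id").alias("this", "is", "not")).collect()
    except grpc.RpcError as e:
        print(e)  # the error message lists the unresolved names: Buffer(this, is, not)
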
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 71633830b56..404581445d0 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.connect.dsl.commands._
 import org.apache.spark.sql.connect.dsl.expressions._
 import org.apache.spark.sql.connect.dsl.plans._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, MapType, Metadata, StringType, StructField, StructType}
 
 /**
  * This suite is based on connect DSL and test that given same dataframe operations, whether
@@ -50,6 +50,9 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     createLocalRelationProto(
       Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()))
 
+  lazy val connectTestRelationMap =
+    createLocalRelationProto(Seq(AttributeReference("id", MapType(StringType, StringType))()))
+
   lazy val sparkTestRelation: DataFrame =
     spark.createDataFrame(
       new java.util.ArrayList[Row](),
@@ -60,6 +63,11 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
       new java.util.ArrayList[Row](),
       StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))))
 
+  lazy val sparkTestRelationMap: DataFrame =
+    spark.createDataFrame(
+      new java.util.ArrayList[Row](),
+      StructType(Seq(StructField("id", MapType(StringType, StringType)))))
+
   lazy val localRelation = createLocalRelationProto(Seq(AttributeReference("id", IntegerType)()))
 
   test("Basic select") {
@@ -140,10 +148,35 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connectPlan2, sparkPlan2)
   }
 
-  test("column alias") {
+  test("SPARK-40809: column alias") {
+    // Simple Test.
     val connectPlan = connectTestRelation.select("id".protoAttr.as("id2"))
     val sparkPlan = sparkTestRelation.select(Column("id").alias("id2"))
     comparePlans(connectPlan, sparkPlan)
+
+    // Scalar columns with metadata
+    val mdJson = "{\"max\": 99}"
+    comparePlans(
+      connectTestRelation.select("id".protoAttr.as("id2", mdJson)),
+      sparkTestRelation.select(Column("id").as("id2", Metadata.fromJson(mdJson))))
+
+    comparePlans(
+      connectTestRelationMap.select(proto_explode("id".protoAttr).as(Seq("a", "b"))),
+      sparkTestRelationMap.select(explode(Column("id")).as(Seq("a", "b"))))
+
+    // Metadata must only be specified for regular Aliases.
+    assertThrows[InvalidPlanInput] {
+      val attr = proto_explode("id".protoAttr)
+      val alias = proto.Expression.Alias
+        .newBuilder()
+        .setExpr(attr)
+        .addName("a")
+        .addName("b")
+        .setMetadata(mdJson)
+        .build()
+      transform(
+        connectTestRelationMap.select(proto.Expression.newBuilder().setAlias(alias).build()))
+    }
   }
 
   test("Aggregate with more than 1 grouping expressions") {
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 417bc7097de..9f610bf18fe 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -17,6 +17,7 @@
 import uuid
 from typing import cast, get_args, TYPE_CHECKING, Callable, Any
 
+import json
 import decimal
 import datetime
 
@@ -82,6 +83,71 @@ class Expression(object):
     def __str__(self) -> str:
         ...
 
+    def alias(self, *alias: str, **kwargs: Any) -> "ColumnAlias":
+        """
+        Returns this column aliased with a new name or names (in the case of expressions that
+        return more than one column, such as explode).
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        alias : str
+            desired column names (collects all positional arguments passed)
+
+        Other Parameters
+        ----------------
+        metadata: dict
+            a dict of information to be stored in ``metadata`` attribute of the
+            corresponding :class:`StructField <pyspark.sql.types.StructField>` (optional, keyword
+            only argument)
+
+        Returns
+        -------
+        :class:`Column`
+            Column representing whether each element of Column is aliased with new name or names.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...      [(2, "Alice"), (5, "Bob")], ["age", "name"])
+        >>> df.select(df.age.alias("age2")).collect()
+        [Row(age2=2), Row(age2=5)]
+        >>> df.select(df.age.alias("age3", metadata={'max': 99})).schema['age3'].metadata['max']
+        99
+        """
+        metadata = kwargs.pop("metadata", None)
+        assert not kwargs, "Unexpected kwargs were passed: %s" % kwargs
+        return ColumnAlias(self, list(alias), metadata)
+
+
+class ColumnAlias(Expression):
+    def __init__(self, parent: Expression, alias: list[str], metadata: Any):
+
+        self._alias = alias
+        self._metadata = metadata
+        self._parent = parent
+
+    def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
+        if len(self._alias) == 1:
+            exp = proto.Expression()
+            exp.alias.name.append(self._alias[0])
+            exp.alias.expr.CopyFrom(self._parent.to_plan(session))
+
+            if self._metadata:
+                exp.alias.metadata = json.dumps(self._metadata)
+            return exp
+        else:
+            if self._metadata:
+                raise ValueError("metadata can only be provided for a single column")
+            exp = proto.Expression()
+            exp.alias.name.extend(self._alias)
+            exp.alias.expr.CopyFrom(self._parent.to_plan(session))
+            return exp
+
+    def __str__(self) -> str:
+        return f"Alias({self._parent}, ({','.join(self._alias)}))"
+
 
 class LiteralExpression(Expression):
     """A literal expression.
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index e2334b66c68..5dd28c0e6a9 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -120,7 +120,7 @@ class DataFrame(object):
         new_frame._plan = plan
         return new_frame
 
-    def select(self, *cols: ColumnOrName) -> "DataFrame":
+    def select(self, *cols: "ExpressionOrString") -> "DataFrame":
         return DataFrame.withPlan(plan.Project(self._plan, *cols), session=self._session)
 
     def agg(self, exprs: Optional[GroupingFrame.MeasuresType]) -> "DataFrame":
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 89718650571..dca9d2cef47 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xc2\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xf0\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -43,7 +43,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _EXPRESSION._serialized_start = 105
-    _EXPRESSION._serialized_end = 3115
+    _EXPRESSION._serialized_end = 3161
     _EXPRESSION_LITERAL._serialized_start = 613
     _EXPRESSION_LITERAL._serialized_end = 2696
     _EXPRESSION_LITERAL_VARCHAR._serialized_start = 1923
@@ -75,5 +75,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_start = 2941
     _EXPRESSION_QUALIFIEDATTRIBUTE._serialized_end = 3026
     _EXPRESSION_ALIAS._serialized_start = 3028
-    _EXPRESSION_ALIAS._serialized_end = 3102
+    _EXPRESSION_ALIAS._serialized_end = 3148
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index ffaf03eade8..ea538b2ebec 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -633,21 +633,37 @@ class Expression(google.protobuf.message.Message):
 
         EXPR_FIELD_NUMBER: builtins.int
         NAME_FIELD_NUMBER: builtins.int
+        METADATA_FIELD_NUMBER: builtins.int
         @property
         def expr(self) -> global___Expression: ...
-        name: builtins.str
+        @property
+        def name(
+            self,
+        ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
+        metadata: builtins.str
+        """Alias metadata expressed as a JSON map."""
         def __init__(
             self,
             *,
             expr: global___Expression | None = ...,
-            name: builtins.str = ...,
+            name: collections.abc.Iterable[builtins.str] | None = ...,
+            metadata: builtins.str | None = ...,
         ) -> None: ...
         def HasField(
-            self, field_name: typing_extensions.Literal["expr", b"expr"]
+            self,
+            field_name: typing_extensions.Literal[
+                "_metadata", b"_metadata", "expr", b"expr", "metadata", b"metadata"
+            ],
         ) -> builtins.bool: ...
         def ClearField(
-            self, field_name: typing_extensions.Literal["expr", b"expr", "name", b"name"]
+            self,
+            field_name: typing_extensions.Literal[
+                "_metadata", b"_metadata", "expr", b"expr", "metadata", b"metadata", "name", b"name"
+            ],
         ) -> None: ...
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_metadata", b"_metadata"]
+        ) -> typing_extensions.Literal["metadata"] | None: ...
 
     LITERAL_FIELD_NUMBER: builtins.int
     UNRESOLVED_ATTRIBUTE_FIELD_NUMBER: builtins.int
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index d31e472faec..a49829cc085 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -33,7 +33,7 @@ from pyspark.sql.types import StructType, StructField, LongType, StringType
 if have_pandas:
     from pyspark.sql.connect.client import RemoteSparkSession, ChannelBuilder
     from pyspark.sql.connect.function_builder import udf
-    from pyspark.sql.connect.functions import lit
+    from pyspark.sql.connect.functions import lit, col
 from pyspark.sql.dataframe import DataFrame
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
@@ -291,6 +291,20 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             actualResult = pandasResult.values.tolist()
             self.assertEqual(len(expectResult), len(actualResult))
 
+    def test_alias(self) -> None:
+        """Testing supported and unsupported alias"""
+        col0 = (
+            self.connect.range(1, 10)
+            .select(col("id").alias("name", metadata={"max": 99}))
+            .schema()
+            .names[0]
+        )
+        self.assertEqual("name", col0)
+
+        with self.assertRaises(grpc.RpcError) as exc:
+            self.connect.range(1, 10).select(col("id").alias("this", "is", "not")).collect()
+        self.assertIn("Buffer(this, is, not)", str(exc.exception))
+
 
 class ChannelBuilderTests(ReusedPySparkTestCase):
     def test_invalid_connection_strings(self):
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index 59e3c97679e..99b63482a24 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -134,6 +134,16 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
         lit_list_plan = fun.lit([fun.lit(10), fun.lit("str")]).to_plan(None)
         self.assertIsNotNone(lit_list_plan)
 
+    def test_column_alias(self) -> None:
+        # SPARK-40809: Support for Column Aliases
+        col0 = fun.col("a").alias("martin")
+        self.assertEqual("Alias(Column(a), (martin))", str(col0))
+
+        col0 = fun.col("a").alias("martin", metadata={"pii": True})
+        plan = col0.to_plan(self.session)
+        self.assertIsNotNone(plan)
+        self.assertEqual(plan.alias.metadata, '{"pii": true}')
+
     def test_column_expressions(self):
         """Test a more complex combination of expressions and their translation into
         the protobuf structure."""
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index b7c49a6df54..f98a67b9964 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -18,6 +18,7 @@ import os
 from typing import Any, Dict, Optional
 import functools
 import unittest
+
 from pyspark.testing.sqlutils import have_pandas
 
 if have_pandas:
@@ -25,6 +26,7 @@ if have_pandas:
     from pyspark.sql.connect.plan import Read, Range, SQL
     from pyspark.testing.utils import search_jar
     from pyspark.sql.connect.plan import LogicalPlan
+    from pyspark.sql.connect.client import RemoteSparkSession
 
     connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect")
 else:
@@ -67,6 +69,7 @@ class MockRemoteSession:
 class PlanOnlyTestFixture(unittest.TestCase):
 
     connect: "MockRemoteSession"
+    session: RemoteSparkSession
 
     @classmethod
     def _read_table(cls, table_name: str) -> "DataFrame":
@@ -99,6 +102,7 @@ class PlanOnlyTestFixture(unittest.TestCase):
     @classmethod
     def setUpClass(cls: Any) -> None:
         cls.connect = MockRemoteSession()
+        cls.session = RemoteSparkSession()
         cls.tbl_name = "test_connect_plan_only_table_1"
 
         cls.connect.set_hook("register_udf", cls._udf_mock)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org