You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/11/16 09:20:29 UTC
[spark] branch master updated: [SPARK-40809][CONNECT][FOLLOW] Support `alias()` in Python client
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0f7eaeee644 [SPARK-40809][CONNECT][FOLLOW] Support `alias()` in Python client
0f7eaeee644 is described below
commit 0f7eaeee6445aaf05310229bdec22f6953113f2e
Author: Martin Grund <ma...@databricks.com>
AuthorDate: Wed Nov 16 18:20:05 2022 +0900
[SPARK-40809][CONNECT][FOLLOW] Support `alias()` in Python client
### What changes were proposed in this pull request?
This extends the implementation of column aliases in Spark Connect by supporting lists of column names and providing the appropriate implementation for the Python side.
### Why are the changes needed?
Compatibility
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
UT in Python and Scala
Closes #38631 from grundprinzip/SPARK-40809-f.
Lead-authored-by: Martin Grund <ma...@databricks.com>
Co-authored-by: Martin Grund <gr...@gmail.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../main/protobuf/spark/connect/expressions.proto | 4 +-
.../org/apache/spark/sql/connect/dsl/package.scala | 30 +++++++++-
.../connect/planner/DataTypeProtoConverter.scala | 20 ++++++-
.../sql/connect/planner/SparkConnectPlanner.scala | 21 ++++++-
.../connect/planner/SparkConnectProtoSuite.scala | 37 +++++++++++-
python/pyspark/sql/connect/column.py | 66 ++++++++++++++++++++++
python/pyspark/sql/connect/dataframe.py | 2 +-
.../pyspark/sql/connect/proto/expressions_pb2.py | 6 +-
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 24 ++++++--
.../sql/tests/connect/test_connect_basic.py | 16 +++++-
.../connect/test_connect_column_expressions.py | 10 ++++
python/pyspark/testing/connectutils.py | 4 ++
12 files changed, 223 insertions(+), 17 deletions(-)
diff --git a/connector/connect/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
index 426e04341d9..ac5fe24d349 100644
--- a/connector/connect/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/expressions.proto
@@ -170,6 +170,8 @@ message Expression {
message Alias {
Expression expr = 1;
- string name = 2;
+ repeated string name = 2;
+ // Alias metadata expressed as a JSON map.
+ optional string metadata = 3;
}
}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 98879b69b84..caff3d8f071 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -79,7 +79,28 @@ package object dsl {
implicit class DslExpression(val expr: Expression) {
def as(alias: String): Expression = Expression
.newBuilder()
- .setAlias(Expression.Alias.newBuilder().setName(alias).setExpr(expr))
+ .setAlias(Expression.Alias.newBuilder().addName(alias).setExpr(expr))
+ .build()
+
+ def as(alias: String, metadata: String): Expression = Expression
+ .newBuilder()
+ .setAlias(
+ Expression.Alias
+ .newBuilder()
+ .setExpr(expr)
+ .addName(alias)
+ .setMetadata(metadata)
+ .build())
+ .build()
+
+ def as(alias: Seq[String]): Expression = Expression
+ .newBuilder()
+ .setAlias(
+ Expression.Alias
+ .newBuilder()
+ .setExpr(expr)
+ .addAllName(alias.asJava)
+ .build())
.build()
def <(other: Expression): Expression =
@@ -101,6 +122,13 @@ package object dsl {
Expression.UnresolvedFunction.newBuilder().addParts("min").addArguments(e))
.build()
+ def proto_explode(e: Expression): Expression =
+ Expression
+ .newBuilder()
+ .setUnresolvedFunction(
+ Expression.UnresolvedFunction.newBuilder().addParts("explode").addArguments(e))
+ .build()
+
/**
* Create an unresolved function from name parts.
*
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
index 0ee90b5e8fb..088030b2dbc 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/DataTypeProtoConverter.scala
@@ -21,7 +21,7 @@ import scala.collection.convert.ImplicitConversions._
import org.apache.spark.connect.proto
import org.apache.spark.sql.SaveMode
-import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, IntegerType, LongType, MapType, StringType, StructField, StructType}
/**
* This object offers methods to convert to/from connect proto to catalyst types.
@@ -32,6 +32,7 @@ object DataTypeProtoConverter {
case proto.DataType.KindCase.I32 => IntegerType
case proto.DataType.KindCase.STRING => StringType
case proto.DataType.KindCase.STRUCT => convertProtoDataTypeToCatalyst(t.getStruct)
+ case proto.DataType.KindCase.MAP => convertProtoDataTypeToCatalyst(t.getMap)
case _ =>
throw InvalidPlanInput(s"Does not support convert ${t.getKindCase} to catalyst types.")
}
@@ -44,6 +45,10 @@ object DataTypeProtoConverter {
StructType.apply(structFields)
}
+ private def convertProtoDataTypeToCatalyst(t: proto.DataType.Map): MapType = {
+ MapType(toCatalystType(t.getKey), toCatalystType(t.getValue))
+ }
+
def toConnectProtoType(t: DataType): proto.DataType = {
t match {
case IntegerType =>
@@ -54,11 +59,24 @@ object DataTypeProtoConverter {
proto.DataType.newBuilder().setI64(proto.DataType.I64.getDefaultInstance).build()
case struct: StructType =>
toConnectProtoStructType(struct)
+ case map: MapType => toConnectProtoMapType(map)
case _ =>
throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
}
}
+ def toConnectProtoMapType(schema: MapType): proto.DataType = {
+ proto.DataType
+ .newBuilder()
+ .setMap(
+ proto.DataType.Map
+ .newBuilder()
+ .setKey(toConnectProtoType(schema.keyType))
+ .setValue(toConnectProtoType(schema.valueType))
+ .build())
+ .build()
+ }
+
def toConnectProtoStructType(schema: StructType): proto.DataType = {
val struct = proto.DataType.Struct.newBuilder()
for (structField <- schema.fields) {
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index f5801d57c29..232f6e10474 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -27,7 +27,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.WriteOperation
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.AliasIdentifier
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
@@ -338,7 +338,9 @@ class SparkConnectPlanner(session: SparkSession) {
case proto.Expression.ExprTypeCase.ALIAS => transformAlias(exp.getAlias)
case proto.Expression.ExprTypeCase.EXPRESSION_STRING =>
transformExpressionString(exp.getExpressionString)
- case _ => throw InvalidPlanInput()
+ case _ =>
+ throw InvalidPlanInput(
+ s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported")
}
}
@@ -412,7 +414,20 @@ class SparkConnectPlanner(session: SparkSession) {
}
private def transformAlias(alias: proto.Expression.Alias): NamedExpression = {
- Alias(transformExpression(alias.getExpr), alias.getName)()
+ if (alias.getNameCount == 1) {
+ val md = if (alias.hasMetadata()) {
+ Some(Metadata.fromJson(alias.getMetadata))
+ } else {
+ None
+ }
+ Alias(transformExpression(alias.getExpr), alias.getName(0))(explicitMetadata = md)
+ } else {
+ if (alias.hasMetadata) {
+ throw new InvalidPlanInput(
+ "Alias expressions with more than 1 identifier must not use optional metadata.")
+ }
+ MultiAlias(transformExpression(alias.getExpr), alias.getNameList.asScala.toSeq)
+ }
}
private def transformExpressionString(expr: proto.Expression.ExpressionString): Expression = {
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 71633830b56..404581445d0 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.connect.dsl.commands._
import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, MapType, Metadata, StringType, StructField, StructType}
/**
* This suite is based on connect DSL and test that given same dataframe operations, whether
@@ -50,6 +50,9 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
createLocalRelationProto(
Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()))
+ lazy val connectTestRelationMap =
+ createLocalRelationProto(Seq(AttributeReference("id", MapType(StringType, StringType))()))
+
lazy val sparkTestRelation: DataFrame =
spark.createDataFrame(
new java.util.ArrayList[Row](),
@@ -60,6 +63,11 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
new java.util.ArrayList[Row](),
StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))))
+ lazy val sparkTestRelationMap: DataFrame =
+ spark.createDataFrame(
+ new java.util.ArrayList[Row](),
+ StructType(Seq(StructField("id", MapType(StringType, StringType)))))
+
lazy val localRelation = createLocalRelationProto(Seq(AttributeReference("id", IntegerType)()))
test("Basic select") {
@@ -140,10 +148,35 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connectPlan2, sparkPlan2)
}
- test("column alias") {
+ test("SPARK-40809: column alias") {
+ // Simple Test.
val connectPlan = connectTestRelation.select("id".protoAttr.as("id2"))
val sparkPlan = sparkTestRelation.select(Column("id").alias("id2"))
comparePlans(connectPlan, sparkPlan)
+
+ // Scalar columns with metadata
+ val mdJson = "{\"max\": 99}"
+ comparePlans(
+ connectTestRelation.select("id".protoAttr.as("id2", mdJson)),
+ sparkTestRelation.select(Column("id").as("id2", Metadata.fromJson(mdJson))))
+
+ comparePlans(
+ connectTestRelationMap.select(proto_explode("id".protoAttr).as(Seq("a", "b"))),
+ sparkTestRelationMap.select(explode(Column("id")).as(Seq("a", "b"))))
+
+ // Metadata must only be specified for regular Aliases.
+ assertThrows[InvalidPlanInput] {
+ val attr = proto_explode("id".protoAttr)
+ val alias = proto.Expression.Alias
+ .newBuilder()
+ .setExpr(attr)
+ .addName("a")
+ .addName("b")
+ .setMetadata(mdJson)
+ .build()
+ transform(
+ connectTestRelationMap.select(proto.Expression.newBuilder().setAlias(alias).build()))
+ }
}
test("Aggregate with more than 1 grouping expressions") {
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 417bc7097de..9f610bf18fe 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -17,6 +17,7 @@
import uuid
from typing import cast, get_args, TYPE_CHECKING, Callable, Any
+import json
import decimal
import datetime
@@ -82,6 +83,71 @@ class Expression(object):
def __str__(self) -> str:
...
+ def alias(self, *alias: str, **kwargs: Any) -> "ColumnAlias":
+ """
+ Returns this column aliased with a new name or names (in the case of expressions that
+ return more than one column, such as explode).
+
+ .. versionadded:: 3.4.0
+
+ Parameters
+ ----------
+ alias : str
+ desired column names (collects all positional arguments passed)
+
+ Other Parameters
+ ----------------
+ metadata: dict
+ a dict of information to be stored in ``metadata`` attribute of the
+ corresponding :class:`StructField <pyspark.sql.types.StructField>` (optional, keyword
+ only argument)
+
+ Returns
+ -------
+ :class:`Column`
+ Column representing whether each element of Column is aliased with new name or names.
+
+ Examples
+ --------
+ >>> df = spark.createDataFrame(
+ ... [(2, "Alice"), (5, "Bob")], ["age", "name"])
+ >>> df.select(df.age.alias("age2")).collect()
+ [Row(age2=2), Row(age2=5)]
+ >>> df.select(df.age.alias("age3", metadata={'max': 99})).schema['age3'].metadata['max']
+ 99
+ """
+ metadata = kwargs.pop("metadata", None)
+ assert not kwargs, "Unexpected kwargs where passed: %s" % kwargs
+ return ColumnAlias(self, list(alias), metadata)
+
+
+class ColumnAlias(Expression):
+ def __init__(self, parent: Expression, alias: list[str], metadata: Any):
+
+ self._alias = alias
+ self._metadata = metadata
+ self._parent = parent
+
+ def to_plan(self, session: "RemoteSparkSession") -> "proto.Expression":
+ if len(self._alias) == 1:
+ exp = proto.Expression()
+ exp.alias.name.append(self._alias[0])
+ exp.alias.expr.CopyFrom(self._parent.to_plan(session))
+
+ if self._metadata:
+ exp.alias.metadata = json.dumps(self._metadata)
+ return exp
+ else:
+ if self._metadata:
+ raise ValueError("metadata can only be provided for a single column")
+ exp = proto.Expression()
+ exp.alias.name.extend(self._alias)
+ exp.alias.expr.CopyFrom(self._parent.to_plan(session))
+ return exp
+
+ def __str__(self) -> str:
+ return f"Alias({self._parent}, ({','.join(self._alias)}))"
+
class LiteralExpression(Expression):
"""A literal expression.
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index e2334b66c68..5dd28c0e6a9 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -120,7 +120,7 @@ class DataFrame(object):
new_frame._plan = plan
return new_frame
- def select(self, *cols: ColumnOrName) -> "DataFrame":
+ def select(self, *cols: "ExpressionOrString") -> "DataFrame":
return DataFrame.withPlan(plan.Project(self._plan, *cols), session=self._session)
def agg(self, exprs: Optional[GroupingFrame.MeasuresType]) -> "DataFrame":
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index 89718650571..dca9d2cef47 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -33,7 +33,7 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xc2\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...]
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19spark/connect/types.proto\x1a\x19google/protobuf/any.proto"\xf0\x17\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFu [...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -43,7 +43,7 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_EXPRESSION._serialized_start = 105
- _EXPRESSION._serialized_end = 3115
+ _EXPRESSION._serialized_end = 3161
_EXPRESSION_LITERAL._serialized_start = 613
_EXPRESSION_LITERAL._serialized_end = 2696
_EXPRESSION_LITERAL_VARCHAR._serialized_start = 1923
@@ -75,5 +75,5 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXPRESSION_QUALIFIEDATTRIBUTE._serialized_start = 2941
_EXPRESSION_QUALIFIEDATTRIBUTE._serialized_end = 3026
_EXPRESSION_ALIAS._serialized_start = 3028
- _EXPRESSION_ALIAS._serialized_end = 3102
+ _EXPRESSION_ALIAS._serialized_end = 3148
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index ffaf03eade8..ea538b2ebec 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -633,21 +633,37 @@ class Expression(google.protobuf.message.Message):
EXPR_FIELD_NUMBER: builtins.int
NAME_FIELD_NUMBER: builtins.int
+ METADATA_FIELD_NUMBER: builtins.int
@property
def expr(self) -> global___Expression: ...
- name: builtins.str
+ @property
+ def name(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: ...
+ metadata: builtins.str
+ """Alias metadata expressed as a JSON map."""
def __init__(
self,
*,
expr: global___Expression | None = ...,
- name: builtins.str = ...,
+ name: collections.abc.Iterable[builtins.str] | None = ...,
+ metadata: builtins.str | None = ...,
) -> None: ...
def HasField(
- self, field_name: typing_extensions.Literal["expr", b"expr"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_metadata", b"_metadata", "expr", b"expr", "metadata", b"metadata"
+ ],
) -> builtins.bool: ...
def ClearField(
- self, field_name: typing_extensions.Literal["expr", b"expr", "name", b"name"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_metadata", b"_metadata", "expr", b"expr", "metadata", b"metadata", "name", b"name"
+ ],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_metadata", b"_metadata"]
+ ) -> typing_extensions.Literal["metadata"] | None: ...
LITERAL_FIELD_NUMBER: builtins.int
UNRESOLVED_ATTRIBUTE_FIELD_NUMBER: builtins.int
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index d31e472faec..a49829cc085 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -33,7 +33,7 @@ from pyspark.sql.types import StructType, StructField, LongType, StringType
if have_pandas:
from pyspark.sql.connect.client import RemoteSparkSession, ChannelBuilder
from pyspark.sql.connect.function_builder import udf
- from pyspark.sql.connect.functions import lit
+ from pyspark.sql.connect.functions import lit, col
from pyspark.sql.dataframe import DataFrame
from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
from pyspark.testing.pandasutils import PandasOnSparkTestCase
@@ -291,6 +291,20 @@ class SparkConnectTests(SparkConnectSQLTestCase):
actualResult = pandasResult.values.tolist()
self.assertEqual(len(expectResult), len(actualResult))
+ def test_alias(self) -> None:
+ """Testing supported and unsupported alias"""
+ col0 = (
+ self.connect.range(1, 10)
+ .select(col("id").alias("name", metadata={"max": 99}))
+ .schema()
+ .names[0]
+ )
+ self.assertEqual("name", col0)
+
+ with self.assertRaises(grpc.RpcError) as exc:
+ self.connect.range(1, 10).select(col("id").alias("this", "is", "not")).collect()
+ self.assertIn("Buffer(this, is, not)", str(exc.exception))
+
class ChannelBuilderTests(ReusedPySparkTestCase):
def test_invalid_connection_strings(self):
diff --git a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
index 59e3c97679e..99b63482a24 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column_expressions.py
@@ -134,6 +134,16 @@ class SparkConnectColumnExpressionSuite(PlanOnlyTestFixture):
lit_list_plan = fun.lit([fun.lit(10), fun.lit("str")]).to_plan(None)
self.assertIsNotNone(lit_list_plan)
+ def test_column_alias(self) -> None:
+ # SPARK-40809: Support for Column Aliases
+ col0 = fun.col("a").alias("martin")
+ self.assertEqual("Alias(Column(a), (martin))", str(col0))
+
+ col0 = fun.col("a").alias("martin", metadata={"pii": True})
+ plan = col0.to_plan(self.session)
+ self.assertIsNotNone(plan)
+ self.assertEqual(plan.alias.metadata, '{"pii": true}')
+
def test_column_expressions(self):
"""Test a more complex combination of expressions and their translation into
the protobuf structure."""
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index b7c49a6df54..f98a67b9964 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -18,6 +18,7 @@ import os
from typing import Any, Dict, Optional
import functools
import unittest
+
from pyspark.testing.sqlutils import have_pandas
if have_pandas:
@@ -25,6 +26,7 @@ if have_pandas:
from pyspark.sql.connect.plan import Read, Range, SQL
from pyspark.testing.utils import search_jar
from pyspark.sql.connect.plan import LogicalPlan
+ from pyspark.sql.connect.client import RemoteSparkSession
connect_jar = search_jar("connector/connect", "spark-connect-assembly-", "spark-connect")
else:
@@ -67,6 +69,7 @@ class MockRemoteSession:
class PlanOnlyTestFixture(unittest.TestCase):
connect: "MockRemoteSession"
+ session: RemoteSparkSession
@classmethod
def _read_table(cls, table_name: str) -> "DataFrame":
@@ -99,6 +102,7 @@ class PlanOnlyTestFixture(unittest.TestCase):
@classmethod
def setUpClass(cls: Any) -> None:
cls.connect = MockRemoteSession()
+ cls.session = RemoteSparkSession()
cls.tbl_name = "test_connect_plan_only_table_1"
cls.connect.set_hook("register_udf", cls._udf_mock)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org