Posted to commits@spark.apache.org by xi...@apache.org on 2023/02/09 10:19:12 UTC
[spark] branch branch-3.4 updated: [SPARK-42210][CONNECT][PYTHON] Standardize registered pickled Python UDFs
This is an automated email from the ASF dual-hosted git repository.
xinrong pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 9e2fc6448e7 [SPARK-42210][CONNECT][PYTHON] Standardize registered pickled Python UDFs
9e2fc6448e7 is described below
commit 9e2fc6448e71c00b831d34e289278e6418d6d59f
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Thu Feb 9 18:18:08 2023 +0800
[SPARK-42210][CONNECT][PYTHON] Standardize registered pickled Python UDFs
### What changes were proposed in this pull request?
Standardize registered pickled Python UDFs; specifically, implement `spark.udf.register()` for Spark Connect.
### Why are the changes needed?
To reach parity with vanilla PySpark.
### Does this PR introduce _any_ user-facing change?
Yes. `spark.udf.register()` is added as shown below:
```py
>>> spark.udf
<pyspark.sql.connect.udf.UDFRegistration object at 0x7fbca0077dc0>
>>> f = spark.udf.register("f", lambda x: x+1, "int")
>>> f
<function <lambda> at 0x7fbc905e5e50>
>>> spark.sql("SELECT f(id) FROM range(2)").collect()
[Row(f(id)=1), Row(f(id)=2)]
```
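The returned UDF is also usable from the DataFrame API, which the re-enabled parity test exercises; a minimal sketch reusing `f` from above (the `plus_one` alias is illustrative):
```py
>>> df = spark.range(2)
>>> df.select(f(df.id).alias("plus_one")).collect()
[Row(plus_one=1), Row(plus_one=2)]
```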
### How was this patch tested?
Unit tests; notably, `test_udf_registration_returns_udf` in `test_parity_udf.py` is re-enabled with a Connect-compatible implementation.
Closes #39860 from xinrong-meng/connect_registered_udf.
Lead-authored-by: Xinrong Meng <xi...@apache.org>
Co-authored-by: Xinrong Meng <xi...@gmail.com>
Signed-off-by: Xinrong Meng <xi...@apache.org>
(cherry picked from commit e7eb836376b72ae58b741e87d40f2d42c9914537)
Signed-off-by: Xinrong Meng <xi...@apache.org>
---
.../src/main/protobuf/spark/connect/commands.proto | 1 +
.../sql/connect/planner/SparkConnectPlanner.scala | 33 ++++++++++++
python/pyspark/sql/connect/client.py | 59 ++++++++++++++++++++++
python/pyspark/sql/connect/expressions.py | 7 +++
python/pyspark/sql/connect/proto/commands_pb2.py | 40 +++++++--------
python/pyspark/sql/connect/proto/commands_pb2.pyi | 17 ++++++-
python/pyspark/sql/connect/session.py | 9 +++-
python/pyspark/sql/connect/udf.py | 58 ++++++++++++++++++++-
python/pyspark/sql/session.py | 6 +--
.../sql/tests/connect/test_connect_basic.py | 1 -
.../pyspark/sql/tests/connect/test_parity_udf.py | 17 ++++---
python/pyspark/sql/udf.py | 3 ++
12 files changed, 216 insertions(+), 35 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 05c91d2c992..73218697577 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -31,6 +31,7 @@ option java_package = "org.apache.spark.connect.proto";
// produce a relational result.
message Command {
oneof command_type {
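+ // Registers a user-defined function in the session catalog.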
+ CommonInlineUserDefinedFunction register_function = 1;
WriteOperation write_operation = 2;
CreateDataFrameViewCommand create_dataframe_view = 3;
WriteOperationV2 write_operation_v2 = 4;
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index c8a0860b871..3bf5d2b1d30 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -44,6 +44,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.command.CreateViewCommand
+import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.internal.CatalogImpl
import org.apache.spark.sql.types._
@@ -1399,6 +1400,8 @@ class SparkConnectPlanner(val session: SparkSession) {
def process(command: proto.Command): Unit = {
command.getCommandTypeCase match {
+ case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
+ handleRegisterUserDefinedFunction(command.getRegisterFunction)
case proto.Command.CommandTypeCase.WRITE_OPERATION =>
handleWriteOperation(command.getWriteOperation)
case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
@@ -1411,6 +1414,36 @@ class SparkConnectPlanner(val session: SparkSession) {
}
}
+ private def handleRegisterUserDefinedFunction(
+ fun: proto.CommonInlineUserDefinedFunction): Unit = {
+ fun.getFunctionCase match {
+ case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+ handleRegisterPythonUDF(fun)
+ case _ =>
+ throw InvalidPlanInput(
+ s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported")
+ }
+ }
+
+ private def handleRegisterPythonUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = {
+ val udf = fun.getPythonUdf
+ val function = transformPythonFunction(udf)
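+ // The declared output type may be a DDL string or, as a fallback, legacy JSON.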
+ val udpf = UserDefinedPythonFunction(
+ name = fun.getFunctionName,
+ func = function,
+ dataType = DataType.parseTypeWithFallback(
+ schema = udf.getOutputType,
+ parser = DataType.fromDDL,
+ fallbackParser = DataType.fromJson) match {
+ case s: DataType => s
+ case other => throw InvalidPlanInput(s"Invalid return type $other")
+ },
+ pythonEvalType = udf.getEvalType,
+ udfDeterministic = fun.getDeterministic)
+
+ session.udf.registerPython(fun.getFunctionName, udpf)
+ }
+
private def handleCommandPlugin(extension: ProtoAny): Unit = {
SparkConnectPluginRegistry.commandRegistry
// Lazily traverse the collection.
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 8cf5fa50693..903981a015b 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -30,6 +30,7 @@ import time
import urllib.parse
import uuid
import json
+import sys
from types import TracebackType
from typing import (
Iterable,
@@ -67,11 +68,18 @@ from pyspark.errors.exceptions.connect import (
TempTableAlreadyExistsException,
IllegalArgumentException,
)
+from pyspark.sql.connect.expressions import (
+ PythonUDF,
+ CommonInlineUserDefinedFunction,
+)
from pyspark.sql.types import (
DataType,
StructType,
StructField,
)
+from pyspark.sql.utils import is_remote
+from pyspark.serializers import CloudPickleSerializer
+from pyspark.rdd import PythonEvalType
def _configure_logging() -> logging.Logger:
@@ -428,6 +436,57 @@ class SparkConnectClient(object):
self._stub = grpc_lib.SparkConnectServiceStub(self._channel)
# Configure logging for the SparkConnect client.
+ def register_udf(
+ self,
+ function: Any,
+ return_type: Union[str, DataType],
+ name: Optional[str] = None,
+ eval_type: int = PythonEvalType.SQL_BATCHED_UDF,
+ deterministic: bool = True,
+ ) -> str:
+ """Create a temporary UDF in the session catalog on the other side. We generate a
+ temporary name for it."""
+
+ from pyspark.sql import SparkSession as PySparkSession
+
+ if name is None:
+ name = f"fun_{uuid.uuid4().hex}"
+
+ # convert str return_type to DataType
+ if isinstance(return_type, str):
+
+ assert is_remote()
+ return_type_schema = ( # a workaround to parse the DataType from DDL strings
+ PySparkSession.builder.getOrCreate()
+ .createDataFrame(data=[], schema=return_type)
+ .schema
+ )
+ assert len(return_type_schema.fields) == 1, "returnType should be singular"
+ return_type = return_type_schema.fields[0].dataType
+
+ # construct a PythonUDF
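+ # the pickled command bundles the function with its return type,
+ # mirroring how vanilla PySpark serializes UDF commands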
+ py_udf = PythonUDF(
+ output_type=return_type.json(),
+ eval_type=eval_type,
+ command=CloudPickleSerializer().dumps((function, return_type)),
+ python_ver="%d.%d" % sys.version_info[:2],
+ )
+
+ # construct a CommonInlineUserDefinedFunction
+ fun = CommonInlineUserDefinedFunction(
+ function_name=name,
+ deterministic=deterministic,
+ arguments=[],
+ function=py_udf,
+ ).to_command(self)
+
+ # construct the request
+ req = self._execute_plan_request_with_metadata()
+ req.plan.command.register_function.CopyFrom(fun)
+
+ self._execute(req)
+ return name
+
def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
return [
PlanMetrics(
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index dcd7c5ebba6..28b796496ec 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -542,6 +542,13 @@ class CommonInlineUserDefinedFunction(Expression):
)
return expr
+ def to_command(self, session: "SparkConnectClient") -> "proto.CommonInlineUserDefinedFunction":
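+ # Unlike `to_plan`, which wraps the function in an Expression, this
+ # returns the bare message so it can be embedded in a Command.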
+ expr = proto.CommonInlineUserDefinedFunction()
+ expr.function_name = self._function_name
+ expr.deterministic = self._deterministic
+ expr.python_udf.CopyFrom(self._function.to_plan(session))
+ return expr
+
def __repr__(self) -> str:
return (
f"{self._function_name}({', '.join([str(arg) for arg in self._arguments])}), "
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index c9a51b04bb6..f7e9260212e 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xcc\x02\n\x07\x43ommand\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\ [...]
+ b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
)
@@ -147,23 +147,23 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_WRITEOPERATIONV2_TABLEPROPERTIESENTRY._options = None
_WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_options = b"8\001"
_COMMAND._serialized_start = 166
- _COMMAND._serialized_end = 498
- _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 501
- _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 651
- _WRITEOPERATION._serialized_start = 654
- _WRITEOPERATION._serialized_end = 1396
- _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1092
- _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1150
- _WRITEOPERATION_BUCKETBY._serialized_start = 1152
- _WRITEOPERATION_BUCKETBY._serialized_end = 1243
- _WRITEOPERATION_SAVEMODE._serialized_start = 1246
- _WRITEOPERATION_SAVEMODE._serialized_end = 1383
- _WRITEOPERATIONV2._serialized_start = 1399
- _WRITEOPERATIONV2._serialized_end = 2194
- _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1092
- _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1150
- _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 1966
- _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2032
- _WRITEOPERATIONV2_MODE._serialized_start = 2035
- _WRITEOPERATIONV2_MODE._serialized_end = 2194
+ _COMMAND._serialized_end = 593
+ _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 596
+ _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 746
+ _WRITEOPERATION._serialized_start = 749
+ _WRITEOPERATION._serialized_end = 1491
+ _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1187
+ _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1245
+ _WRITEOPERATION_BUCKETBY._serialized_start = 1247
+ _WRITEOPERATION_BUCKETBY._serialized_end = 1338
+ _WRITEOPERATION_SAVEMODE._serialized_start = 1341
+ _WRITEOPERATION_SAVEMODE._serialized_end = 1478
+ _WRITEOPERATIONV2._serialized_start = 1494
+ _WRITEOPERATIONV2._serialized_end = 2289
+ _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1187
+ _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1245
+ _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2061
+ _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2127
+ _WRITEOPERATIONV2_MODE._serialized_start = 2130
+ _WRITEOPERATIONV2_MODE._serialized_end = 2289
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index 8a1f2ffb122..4bdf1f1ed4e 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -59,11 +59,16 @@ class Command(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
+ REGISTER_FUNCTION_FIELD_NUMBER: builtins.int
WRITE_OPERATION_FIELD_NUMBER: builtins.int
CREATE_DATAFRAME_VIEW_FIELD_NUMBER: builtins.int
WRITE_OPERATION_V2_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
+ def register_function(
+ self,
+ ) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction: ...
+ @property
def write_operation(self) -> global___WriteOperation: ...
@property
def create_dataframe_view(self) -> global___CreateDataFrameViewCommand: ...
@@ -77,6 +82,8 @@ class Command(google.protobuf.message.Message):
def __init__(
self,
*,
+ register_function: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
+ | None = ...,
write_operation: global___WriteOperation | None = ...,
create_dataframe_view: global___CreateDataFrameViewCommand | None = ...,
write_operation_v2: global___WriteOperationV2 | None = ...,
@@ -91,6 +98,8 @@ class Command(google.protobuf.message.Message):
b"create_dataframe_view",
"extension",
b"extension",
+ "register_function",
+ b"register_function",
"write_operation",
b"write_operation",
"write_operation_v2",
@@ -106,6 +115,8 @@ class Command(google.protobuf.message.Message):
b"create_dataframe_view",
"extension",
b"extension",
+ "register_function",
+ b"register_function",
"write_operation",
b"write_operation",
"write_operation_v2",
@@ -115,7 +126,11 @@ class Command(google.protobuf.message.Message):
def WhichOneof(
self, oneof_group: typing_extensions.Literal["command_type", b"command_type"]
) -> typing_extensions.Literal[
- "write_operation", "create_dataframe_view", "write_operation_v2", "extension"
+ "register_function",
+ "write_operation",
+ "create_dataframe_view",
+ "write_operation_v2",
+ "extension",
] | None: ...
global___Command = Command
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 898baa45b03..3c44d06bb1c 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -68,6 +68,7 @@ from pyspark.sql.utils import to_str
if TYPE_CHECKING:
from pyspark.sql.connect._typing import OptionalPrimitiveType
from pyspark.sql.connect.catalog import Catalog
+ from pyspark.sql.connect.udf import UDFRegistration
class SparkSession:
@@ -436,8 +437,12 @@ class SparkSession:
raise NotImplementedError("readStream() is not implemented.")
@property
- def udf(self) -> Any:
- raise NotImplementedError("udf() is not implemented.")
+ def udf(self) -> "UDFRegistration":
+ from pyspark.sql.connect.udf import UDFRegistration
+
+ return UDFRegistration(self)
+
+ udf.__doc__ = PySparkSession.udf.__doc__
@property
def version(self) -> str:
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 6571cf76929..573d8f582e2 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -23,8 +23,9 @@ check_dependencies(__name__, __file__)
import sys
import functools
-from typing import Callable, Any, TYPE_CHECKING, Optional
+from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union
+from pyspark.rdd import PythonEvalType
from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.connect.expressions import (
ColumnReference,
@@ -33,6 +34,7 @@ from pyspark.sql.connect.expressions import (
)
from pyspark.sql.connect.column import Column
from pyspark.sql.types import DataType, StringType
+from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration
from pyspark.sql.utils import is_remote
@@ -42,6 +44,7 @@ if TYPE_CHECKING:
DataTypeOrString,
UserDefinedFunctionLike,
)
+ from pyspark.sql.connect.session import SparkSession
from pyspark.sql.types import StringType
@@ -75,7 +78,7 @@ class UserDefinedFunction:
func: Callable[..., Any],
returnType: "DataTypeOrString" = StringType(),
name: Optional[str] = None,
- evalType: int = 100,
+ evalType: int = PythonEvalType.SQL_BATCHED_UDF,
deterministic: bool = True,
):
if not callable(func):
@@ -187,3 +190,54 @@ class UserDefinedFunction:
"""
self.deterministic = False
return self
+
+
+class UDFRegistration:
+ """
+ Wrapper for user-defined function registration.
+ """
+
+ def __init__(self, sparkSession: "SparkSession"):
+ self.sparkSession = sparkSession
+
+ def register(
+ self,
+ name: str,
+ f: Union[Callable[..., Any], "UserDefinedFunctionLike"],
+ returnType: Optional["DataTypeOrString"] = None,
+ ) -> "UserDefinedFunctionLike":
+ # Check whether the input is already a user-defined function or a plain
+ # Python function.
+ if hasattr(f, "asNondeterministic"):
+ if returnType is not None:
+ raise TypeError(
+ "Invalid return type: data type can not be specified when f is"
+ "a user-defined function, but got %s." % returnType
+ )
+ f = cast("UserDefinedFunctionLike", f)
+ if f.evalType not in [
+ PythonEvalType.SQL_BATCHED_UDF,
+ PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+ PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+ PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+ ]:
+ raise ValueError(
+ "Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
+ "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF."
+ )
+ return_udf = f
+ self.sparkSession._client.register_udf(
f.func, f.returnType, name, f.evalType, f.deterministic
+ )
+ else:
+ if returnType is None:
+ returnType = StringType()
+ return_udf = _create_udf(
+ f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name
+ )
+
+ self.sparkSession._client.register_udf(f, returnType, name)
+
+ return return_udf
+
+ register.__doc__ = PySparkUDFRegistration.register.__doc__
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 942e1da95c8..38c93b2d0ac 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -709,15 +709,15 @@ class SparkSession(SparkConversionMixin):
.. versionadded:: 2.0.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Returns
-------
:class:`UDFRegistration`
Examples
--------
- >>> spark.udf
- <pyspark.sql.udf.UDFRegistration object ...>
-
Register a Python UDF, and use it in SQL.
>>> strlen = spark.udf.register("strlen", lambda x: len(x))
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index eebfaaa39d8..b8e2c7b151a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2734,7 +2734,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
"sparkContext",
"streams",
"readStream",
- "udf",
"version",
):
with self.assertRaises(NotImplementedError):
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py
index 8d4bb69bf16..5fe1dee7fe8 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -19,7 +19,7 @@ import unittest
from pyspark.testing.connectutils import should_test_connect
-if should_test_connect: # test_udf_with_partial_function
+if should_test_connect:
from pyspark import sql
from pyspark.sql.connect.udf import UserDefinedFunction
@@ -27,6 +27,7 @@ if should_test_connect: # test_udf_with_partial_function
from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.sql.types import IntegerType
class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
@@ -149,11 +150,6 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
def test_non_existed_udf(self):
super().test_non_existed_udf()
- # TODO(SPARK-42210): implement `spark.udf`
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_udf_registration_returns_udf(self):
- super().test_udf_registration_returns_udf()
-
# TODO(SPARK-42210): implement `spark.udf`
@unittest.skip("Fails in Spark Connect, should enable.")
def test_register_java_function(self):
@@ -179,6 +175,15 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
def test_udf_in_subquery(self):
super().test_udf_in_subquery()
+ def test_udf_registration_returns_udf(self):
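+ # the returned UDF and the SQL-registered function must yield
+ # identical results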
+ df = self.spark.range(10)
+ add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
+
+ self.assertListEqual(
+ df.selectExpr("add_three(id) AS plus_three").collect(),
+ df.select(add_three("id").alias("plus_three")).collect(),
+ )
+
if __name__ == "__main__":
import unittest
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 79ae456b1f7..9f8e3e46977 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -477,6 +477,9 @@ class UDFRegistration:
.. versionadded:: 1.3.1
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Parameters
----------
name : str,
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org