You are viewing a plain-text version of this content. (The canonical link to it was not preserved in this copy.)
Posted to commits@spark.apache.org by xi...@apache.org on 2023/02/09 10:19:12 UTC

[spark] branch branch-3.4 updated: [SPARK-42210][CONNECT][PYTHON] Standardize registered pickled Python UDFs

This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 9e2fc6448e7 [SPARK-42210][CONNECT][PYTHON] Standardize registered pickled Python UDFs
9e2fc6448e7 is described below

commit 9e2fc6448e71c00b831d34e289278e6418d6d59f
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Thu Feb 9 18:18:08 2023 +0800

    [SPARK-42210][CONNECT][PYTHON] Standardize registered pickled Python UDFs
    
    ### What changes were proposed in this pull request?
    Standardize registered pickled Python UDFs; specifically, implement `spark.udf.register()`.
    
    ### Why are the changes needed?
    To reach parity with vanilla PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `spark.udf.register()` is added as shown below:
    
    ```py
    >>> spark.udf
    <pyspark.sql.connect.udf.UDFRegistration object at 0x7fbca0077dc0>
    >>> f = spark.udf.register("f", lambda x: x+1, "int")
    >>> f
    <function <lambda> at 0x7fbc905e5e50>
    >>> spark.sql("SELECT f(id) FROM range(2)").collect()
    [Row(f(id)=1), Row(f(id)=2)]
    ```
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #39860 from xinrong-meng/connect_registered_udf.
    
    Lead-authored-by: Xinrong Meng <xi...@apache.org>
    Co-authored-by: Xinrong Meng <xi...@gmail.com>
    Signed-off-by: Xinrong Meng <xi...@apache.org>
    (cherry picked from commit e7eb836376b72ae58b741e87d40f2d42c9914537)
    Signed-off-by: Xinrong Meng <xi...@apache.org>
---
 .../src/main/protobuf/spark/connect/commands.proto |  1 +
 .../sql/connect/planner/SparkConnectPlanner.scala  | 33 ++++++++++++
 python/pyspark/sql/connect/client.py               | 59 ++++++++++++++++++++++
 python/pyspark/sql/connect/expressions.py          |  7 +++
 python/pyspark/sql/connect/proto/commands_pb2.py   | 40 +++++++--------
 python/pyspark/sql/connect/proto/commands_pb2.pyi  | 17 ++++++-
 python/pyspark/sql/connect/session.py              |  9 +++-
 python/pyspark/sql/connect/udf.py                  | 58 ++++++++++++++++++++-
 python/pyspark/sql/session.py                      |  6 +--
 .../sql/tests/connect/test_connect_basic.py        |  1 -
 .../pyspark/sql/tests/connect/test_parity_udf.py   | 17 ++++---
 python/pyspark/sql/udf.py                          |  3 ++
 12 files changed, 216 insertions(+), 35 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 05c91d2c992..73218697577 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -31,6 +31,7 @@ option java_package = "org.apache.spark.connect.proto";
 // produce a relational result.
 message Command {
   oneof command_type {
+    CommonInlineUserDefinedFunction register_function = 1;
     WriteOperation write_operation = 2;
     CreateDataFrameViewCommand create_dataframe_view = 3;
     WriteOperationV2 write_operation_v2 = 4;
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index c8a0860b871..3bf5d2b1d30 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -44,6 +44,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.CreateViewCommand
+import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.functions.{col, expr}
 import org.apache.spark.sql.internal.CatalogImpl
 import org.apache.spark.sql.types._
@@ -1399,6 +1400,8 @@ class SparkConnectPlanner(val session: SparkSession) {
 
   def process(command: proto.Command): Unit = {
     command.getCommandTypeCase match {
+      case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
+        handleRegisterUserDefinedFunction(command.getRegisterFunction)
       case proto.Command.CommandTypeCase.WRITE_OPERATION =>
         handleWriteOperation(command.getWriteOperation)
       case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
@@ -1411,6 +1414,36 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
 
+  /** Handles a `register_function` command; only Python UDFs are supported so far. */
+  private def handleRegisterUserDefinedFunction(
+      fun: proto.CommonInlineUserDefinedFunction): Unit = {
+    fun.getFunctionCase match {
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
+        handleRegisterPythonUDF(fun)
+      case _ =>
+        throw InvalidPlanInput(
+          s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported")
+    }
+  }
+
+  /** Registers a pickled Python UDF in `session.udf` under the client-supplied name. */
+  private def handleRegisterPythonUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = {
+    val udf = fun.getPythonUdf
+    val function = transformPythonFunction(udf)
+    // Already a DataType: the DDL parser is tried first, the JSON parser as a fallback.
+    val outputType = DataType.parseTypeWithFallback(
+      schema = udf.getOutputType,
+      parser = DataType.fromDDL,
+      fallbackParser = DataType.fromJson)
+    val udpf = UserDefinedPythonFunction(
+      name = fun.getFunctionName,
+      func = function,
+      dataType = outputType,
+      pythonEvalType = udf.getEvalType,
+      udfDeterministic = fun.getDeterministic)
+    session.udf.registerPython(fun.getFunctionName, udpf)
+  }
+
   private def handleCommandPlugin(extension: ProtoAny): Unit = {
     SparkConnectPluginRegistry.commandRegistry
       // Lazily traverse the collection.
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 8cf5fa50693..903981a015b 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -30,6 +30,7 @@ import time
 import urllib.parse
 import uuid
 import json
+import sys
 from types import TracebackType
 from typing import (
     Iterable,
@@ -67,11 +68,18 @@ from pyspark.errors.exceptions.connect import (
     TempTableAlreadyExistsException,
     IllegalArgumentException,
 )
+from pyspark.sql.connect.expressions import (
+    PythonUDF,
+    CommonInlineUserDefinedFunction,
+)
 from pyspark.sql.types import (
     DataType,
     StructType,
     StructField,
 )
+from pyspark.sql.utils import is_remote
+from pyspark.serializers import CloudPickleSerializer
+from pyspark.rdd import PythonEvalType
 
 
 def _configure_logging() -> logging.Logger:
@@ -428,6 +436,57 @@ class SparkConnectClient(object):
         self._stub = grpc_lib.SparkConnectServiceStub(self._channel)
         # Configure logging for the SparkConnect client.
 
+    def register_udf(
+        self,
+        function: Any,
+        return_type: Union[str, DataType],
+        name: Optional[str] = None,
+        eval_type: int = PythonEvalType.SQL_BATCHED_UDF,
+        deterministic: bool = True,
+    ) -> str:
+        """Register ``function`` as a temporary UDF in the session catalog on the server.
+
+        A unique ``fun_<hex>`` name is generated only when ``name`` is None. A string
+        ``return_type`` is converted to a ``DataType`` and must describe exactly one field
+        (checked with ``assert``). Returns the name the function was registered under.
+        """
+        from pyspark.sql import SparkSession as PySparkSession
+
+        if name is None:
+            name = f"fun_{uuid.uuid4().hex}"
+
+        if isinstance(return_type, str):
+            # Workaround: parse the type string via the schema of an empty DataFrame.
+            assert is_remote()
+            return_type_schema = (
+                PySparkSession.builder.getOrCreate()
+                .createDataFrame(data=[], schema=return_type)
+                .schema
+            )
+            assert len(return_type_schema.fields) == 1, "returnType should be singular"
+            return_type = return_type_schema.fields[0].dataType
+
+        # The command payload pickles the (function, return_type) pair.
+        py_udf = PythonUDF(
+            output_type=return_type.json(),
+            eval_type=eval_type,
+            command=CloudPickleSerializer().dumps((function, return_type)),
+            python_ver="%d.%d" % sys.version_info[:2],
+        )
+
+        # No call arguments are attached: this is a registration, not an invocation.
+        fun = CommonInlineUserDefinedFunction(
+            function_name=name,
+            deterministic=deterministic,
+            arguments=[],
+            function=py_udf,
+        ).to_command(self)
+
+        req = self._execute_plan_request_with_metadata()
+        req.plan.command.register_function.CopyFrom(fun)
+        self._execute(req)
+        return name
+
     def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
         return [
             PlanMetrics(
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index dcd7c5ebba6..28b796496ec 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -542,6 +542,13 @@ class CommonInlineUserDefinedFunction(Expression):
         )
         return expr
 
+    def to_command(self, session: "SparkConnectClient") -> "proto.CommonInlineUserDefinedFunction":
+        """Build the message for a `register_function` command; ``self._arguments`` is not sent."""
+        cmd = proto.CommonInlineUserDefinedFunction()
+        cmd.function_name, cmd.deterministic = self._function_name, self._deterministic
+        cmd.python_udf.CopyFrom(self._function.to_plan(session))
+        return cmd
+
     def __repr__(self) -> str:
         return (
             f"{self._function_name}({', '.join([str(arg) for arg in self._arguments])}), "
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index c9a51b04bb6..f7e9260212e 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xcc\x02\n\x07\x43ommand\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\ [...]
+    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
 )
 
 
@@ -147,23 +147,23 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._options = None
     _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_options = b"8\001"
     _COMMAND._serialized_start = 166
-    _COMMAND._serialized_end = 498
-    _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 501
-    _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 651
-    _WRITEOPERATION._serialized_start = 654
-    _WRITEOPERATION._serialized_end = 1396
-    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1092
-    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1150
-    _WRITEOPERATION_BUCKETBY._serialized_start = 1152
-    _WRITEOPERATION_BUCKETBY._serialized_end = 1243
-    _WRITEOPERATION_SAVEMODE._serialized_start = 1246
-    _WRITEOPERATION_SAVEMODE._serialized_end = 1383
-    _WRITEOPERATIONV2._serialized_start = 1399
-    _WRITEOPERATIONV2._serialized_end = 2194
-    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1092
-    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1150
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 1966
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2032
-    _WRITEOPERATIONV2_MODE._serialized_start = 2035
-    _WRITEOPERATIONV2_MODE._serialized_end = 2194
+    _COMMAND._serialized_end = 593
+    _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 596
+    _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 746
+    _WRITEOPERATION._serialized_start = 749
+    _WRITEOPERATION._serialized_end = 1491
+    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1187
+    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1245
+    _WRITEOPERATION_BUCKETBY._serialized_start = 1247
+    _WRITEOPERATION_BUCKETBY._serialized_end = 1338
+    _WRITEOPERATION_SAVEMODE._serialized_start = 1341
+    _WRITEOPERATION_SAVEMODE._serialized_end = 1478
+    _WRITEOPERATIONV2._serialized_start = 1494
+    _WRITEOPERATIONV2._serialized_end = 2289
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1187
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1245
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2061
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2127
+    _WRITEOPERATIONV2_MODE._serialized_start = 2130
+    _WRITEOPERATIONV2_MODE._serialized_end = 2289
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index 8a1f2ffb122..4bdf1f1ed4e 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -59,11 +59,16 @@ class Command(google.protobuf.message.Message):
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
+    REGISTER_FUNCTION_FIELD_NUMBER: builtins.int
     WRITE_OPERATION_FIELD_NUMBER: builtins.int
     CREATE_DATAFRAME_VIEW_FIELD_NUMBER: builtins.int
     WRITE_OPERATION_V2_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
+    def register_function(
+        self,
+    ) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction: ...
+    @property
     def write_operation(self) -> global___WriteOperation: ...
     @property
     def create_dataframe_view(self) -> global___CreateDataFrameViewCommand: ...
@@ -77,6 +82,8 @@ class Command(google.protobuf.message.Message):
     def __init__(
         self,
         *,
+        register_function: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
+        | None = ...,
         write_operation: global___WriteOperation | None = ...,
         create_dataframe_view: global___CreateDataFrameViewCommand | None = ...,
         write_operation_v2: global___WriteOperationV2 | None = ...,
@@ -91,6 +98,8 @@ class Command(google.protobuf.message.Message):
             b"create_dataframe_view",
             "extension",
             b"extension",
+            "register_function",
+            b"register_function",
             "write_operation",
             b"write_operation",
             "write_operation_v2",
@@ -106,6 +115,8 @@ class Command(google.protobuf.message.Message):
             b"create_dataframe_view",
             "extension",
             b"extension",
+            "register_function",
+            b"register_function",
             "write_operation",
             b"write_operation",
             "write_operation_v2",
@@ -115,7 +126,11 @@ class Command(google.protobuf.message.Message):
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["command_type", b"command_type"]
     ) -> typing_extensions.Literal[
-        "write_operation", "create_dataframe_view", "write_operation_v2", "extension"
+        "register_function",
+        "write_operation",
+        "create_dataframe_view",
+        "write_operation_v2",
+        "extension",
     ] | None: ...
 
 global___Command = Command
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 898baa45b03..3c44d06bb1c 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -68,6 +68,7 @@ from pyspark.sql.utils import to_str
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import OptionalPrimitiveType
     from pyspark.sql.connect.catalog import Catalog
+    from pyspark.sql.connect.udf import UDFRegistration
 
 
 class SparkSession:
@@ -436,8 +437,12 @@ class SparkSession:
         raise NotImplementedError("readStream() is not implemented.")
 
     @property
-    def udf(self) -> Any:
-        raise NotImplementedError("udf() is not implemented.")
+    def udf(self) -> "UDFRegistration":
+        # Local import (presumably to avoid an import cycle); a new wrapper is built per access.
+        from pyspark.sql.connect.udf import UDFRegistration
+        return UDFRegistration(self)
+
+    udf.__doc__ = PySparkSession.udf.__doc__
 
     @property
     def version(self) -> str:
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 6571cf76929..573d8f582e2 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -23,8 +23,9 @@ check_dependencies(__name__, __file__)
 
 import sys
 import functools
-from typing import Callable, Any, TYPE_CHECKING, Optional
+from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union
 
+from pyspark.rdd import PythonEvalType
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.sql.connect.expressions import (
     ColumnReference,
@@ -33,6 +34,7 @@ from pyspark.sql.connect.expressions import (
 )
 from pyspark.sql.connect.column import Column
 from pyspark.sql.types import DataType, StringType
+from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration
 from pyspark.sql.utils import is_remote
 
 
@@ -42,6 +44,7 @@ if TYPE_CHECKING:
         DataTypeOrString,
         UserDefinedFunctionLike,
     )
+    from pyspark.sql.connect.session import SparkSession
     from pyspark.sql.types import StringType
 
 
@@ -75,7 +78,7 @@ class UserDefinedFunction:
         func: Callable[..., Any],
         returnType: "DataTypeOrString" = StringType(),
         name: Optional[str] = None,
-        evalType: int = 100,
+        evalType: int = PythonEvalType.SQL_BATCHED_UDF,
         deterministic: bool = True,
     ):
         if not callable(func):
@@ -187,3 +190,54 @@ class UserDefinedFunction:
         """
         self.deterministic = False
         return self
+
+
+class UDFRegistration:
+    """
+    Registers Python functions and UDFs as named functions in a Spark Connect session.
+    """
+
+    def __init__(self, sparkSession: "SparkSession"):
+        self.sparkSession = sparkSession
+
+    def register(
+        self,
+        name: str,
+        f: Union[Callable[..., Any], "UserDefinedFunctionLike"],
+        returnType: Optional["DataTypeOrString"] = None,
+    ) -> "UserDefinedFunctionLike":
+        # UDF objects expose `asNondeterministic`; plain Python callables do not.
+        if hasattr(f, "asNondeterministic"):
+            if returnType is not None:
+                raise TypeError(
+                    "Invalid return type: data type can not be specified when f is "
+                    "a user-defined function, but got %s." % returnType
+                )
+            f = cast("UserDefinedFunctionLike", f)
+            if f.evalType not in [
+                PythonEvalType.SQL_BATCHED_UDF,
+                PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+                PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+                PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+            ]:
+                raise ValueError(
+                    "Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
+                    "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF."
+                )
+            return_udf = f
+            # Pickle the underlying Python function (`f.func`), not the UDF wrapper around it.
+            self.sparkSession._client.register_udf(
+                f.func, f.returnType, name, f.evalType, f.deterministic
+            )
+        else:
+            # Plain callable: wrap it as a SQL_BATCHED_UDF (return type defaults to StringType).
+            if returnType is None:
+                returnType = StringType()
+            return_udf = _create_udf(
+                f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name
+            )
+            self.sparkSession._client.register_udf(f, returnType, name)
+
+        return return_udf
+
+    register.__doc__ = PySparkUDFRegistration.register.__doc__
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 942e1da95c8..38c93b2d0ac 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -709,15 +709,15 @@ class SparkSession(SparkConversionMixin):
 
         .. versionadded:: 2.0.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Returns
         -------
         :class:`UDFRegistration`
 
         Examples
         --------
-        >>> spark.udf
-        <pyspark.sql.udf.UDFRegistration object ...>
-
         Register a Python UDF, and use it in SQL.
 
         >>> strlen = spark.udf.register("strlen", lambda x: len(x))
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index eebfaaa39d8..b8e2c7b151a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2734,7 +2734,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             "sparkContext",
             "streams",
             "readStream",
-            "udf",
             "version",
         ):
             with self.assertRaises(NotImplementedError):
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py
index 8d4bb69bf16..5fe1dee7fe8 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -19,7 +19,7 @@ import unittest
 
 from pyspark.testing.connectutils import should_test_connect
 
-if should_test_connect:  # test_udf_with_partial_function
+if should_test_connect:
     from pyspark import sql
     from pyspark.sql.connect.udf import UserDefinedFunction
 
@@ -27,6 +27,7 @@ if should_test_connect:  # test_udf_with_partial_function
 
 from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
+from pyspark.sql.types import IntegerType
 
 
 class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
@@ -149,11 +150,6 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
     def test_non_existed_udf(self):
         super().test_non_existed_udf()
 
-    # TODO(SPARK-42210): implement `spark.udf`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udf_registration_returns_udf(self):
-        super().test_udf_registration_returns_udf()
-
     # TODO(SPARK-42210): implement `spark.udf`
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_register_java_function(self):
@@ -179,6 +175,15 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
     def test_udf_in_subquery(self):
         super().test_udf_in_subquery()
 
+    def test_udf_registration_returns_udf(self):
+        df = self.spark.range(10)
+        add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
+        # Calling the UDF by its registered SQL name and via the returned object must agree.
+        expected = df.selectExpr("add_three(id) AS plus_three").collect()
+        actual = df.select(add_three("id").alias("plus_three")).collect()
+
+        self.assertListEqual(expected, actual)
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 79ae456b1f7..9f8e3e46977 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -477,6 +477,9 @@ class UDFRegistration:
 
         .. versionadded:: 1.3.1
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         name : str,


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org