Posted to commits@spark.apache.org by xi...@apache.org on 2023/03/08 06:24:06 UTC

[spark] branch branch-3.4 updated: [SPARK-42643][CONNECT][PYTHON] Register Java (aggregate) user-defined functions

This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 0e959a53908 [SPARK-42643][CONNECT][PYTHON] Register Java (aggregate) user-defined functions
0e959a53908 is described below

commit 0e959a539086cda5dd911477ee5568ab540a2249
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Wed Mar 8 14:23:18 2023 +0800

    [SPARK-42643][CONNECT][PYTHON] Register Java (aggregate) user-defined functions
    
    ### What changes were proposed in this pull request?
    Implement `spark.udf.registerJavaFunction` and `spark.udf.registerJavaUDAF`.
    A new proto message `JavaUDF` is introduced.
    
    ### Why are the changes needed?
    Parity with vanilla PySpark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. `spark.udf.registerJavaFunction` and `spark.udf.registerJavaUDAF` are supported now.
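    
    A minimal usage sketch (the Java class names are hypothetical and assumed
    to be on the server's classpath):
    
        from pyspark.sql.types import IntegerType
    
        # scalar Java UDF; the return type may be a DataType, a DDL string, or omitted
        spark.udf.registerJavaFunction("strlen", "org.example.StrLen", IntegerType())
        spark.sql("SELECT strlen('connect')").show()
    
        # Java UDAF; the return type is always inferred from the class
        spark.udf.registerJavaUDAF("myAvg", "org.example.MyDoubleAvg")
        spark.sql("SELECT myAvg(id) FROM range(10)").show()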
    
    ### How was this patch tested?
    Parity unit tests.
    
    Closes #40244 from xinrong-meng/registerJava.
    
    Authored-by: Xinrong Meng <xi...@apache.org>
    Signed-off-by: Xinrong Meng <xi...@apache.org>
    (cherry picked from commit 92aa08786feaf473330a863d19b0c902b721789e)
    Signed-off-by: Xinrong Meng <xi...@apache.org>
---
 .../main/protobuf/spark/connect/expressions.proto  | 13 ++++-
 .../sql/connect/planner/SparkConnectPlanner.scala  | 21 ++++++++
 python/pyspark/sql/connect/client.py               | 39 ++++++++++++++-
 python/pyspark/sql/connect/expressions.py          | 44 +++++++++++++++--
 .../pyspark/sql/connect/proto/expressions_pb2.py   | 26 +++++++---
 .../pyspark/sql/connect/proto/expressions_pb2.pyi  | 56 +++++++++++++++++++++-
 python/pyspark/sql/connect/udf.py                  | 17 ++++++-
 .../pyspark/sql/tests/connect/test_parity_udf.py   | 30 +++---------
 python/pyspark/sql/udf.py                          |  6 +++
 9 files changed, 212 insertions(+), 40 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 6eb769ad27e..0aee3ca13b9 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -312,7 +312,7 @@ message Expression {
 message CommonInlineUserDefinedFunction {
   // (Required) Name of the user-defined function.
   string function_name = 1;
-  // (Required) Indicate if the user-defined function is deterministic.
+  // (Optional) Indicate if the user-defined function is deterministic.
   bool deterministic = 2;
   // (Optional) Function arguments. Empty arguments are allowed.
   repeated Expression arguments = 3;
@@ -320,6 +320,7 @@ message CommonInlineUserDefinedFunction {
   oneof function {
     PythonUDF python_udf = 4;
     ScalarScalaUDF scalar_scala_udf = 5;
+    JavaUDF java_udf = 6;
   }
 }
 
@@ -345,3 +346,13 @@ message ScalarScalaUDF {
   bool nullable = 4;
 }
 
+message JavaUDF {
+  // (Required) Fully qualified name of Java class
+  string class_name = 1;
+
+  // (Optional) Output type of the Java UDF
+  optional string output_type = 2;
+
+  // (Required) Indicate if the Java user-defined function is an aggregate function
+  bool aggregate = 3;
+}
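
For reference, the generated Python bindings expose the new message directly; a
minimal sketch of populating it (the class name is hypothetical, and
`output_type` carries a JSON-serialized `DataType`):

    from pyspark.sql.connect.proto import expressions_pb2 as pb2
    from pyspark.sql.types import IntegerType

    msg = pb2.JavaUDF()
    msg.class_name = "org.example.StrLen"   # hypothetical fully qualified class
    msg.output_type = IntegerType().json()  # optional; omit to let the server infer it
    msg.aggregate = False                   # True marks the class as a UDAF
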
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d7b3c057d92..3b9443f4e3c 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1552,6 +1552,8 @@ class SparkConnectPlanner(val session: SparkSession) {
     fun.getFunctionCase match {
       case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
         handleRegisterPythonUDF(fun)
+      case proto.CommonInlineUserDefinedFunction.FunctionCase.JAVA_UDF =>
+        handleRegisterJavaUDF(fun)
       case _ =>
         throw InvalidPlanInput(
           s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported")
@@ -1577,6 +1579,25 @@ class SparkConnectPlanner(val session: SparkSession) {
     session.udf.registerPython(fun.getFunctionName, udpf)
   }
 
+  private def handleRegisterJavaUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = {
+    val udf = fun.getJavaUdf
+    val dataType =
+      if (udf.hasOutputType) {
+        DataType.parseTypeWithFallback(
+          schema = udf.getOutputType,
+          parser = DataType.fromDDL,
+          fallbackParser = DataType.fromJson) match {
+          case s: DataType => s
+          case other => throw InvalidPlanInput(s"Invalid return type $other")
+        }
+      } else null
+    if (udf.getAggregate) {
+      session.udf.registerJavaUDAF(fun.getFunctionName, udf.getClassName)
+    } else {
+      session.udf.registerJava(fun.getFunctionName, udf.getClassName, dataType)
+    }
+  }
+
   private def handleCommandPlugin(extension: ProtoAny): Unit = {
     SparkConnectPluginRegistry.commandRegistry
       // Lazily traverse the collection.
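
Note the type handling above: `parseTypeWithFallback` tries DDL parsing first
and falls back to the JSON format. The Python client (below) always sends
`DataType.json()`, so it exercises the fallback path; other clients may send
plain DDL. From the user's side the accepted forms look like this (hypothetical
class name):

    from pyspark.sql.types import IntegerType

    spark.udf.registerJavaFunction("strlen", "org.example.StrLen")                  # type inferred server-side
    spark.udf.registerJavaFunction("strlen", "org.example.StrLen", "integer")       # DDL string
    spark.udf.registerJavaFunction("strlen", "org.example.StrLen", IntegerType())   # DataType instance
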
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 8c85f17bb5f..6334036fca4 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -47,6 +47,7 @@ from typing import (
     Callable,
     Generator,
     Type,
+    TYPE_CHECKING,
 )
 
 import pandas as pd
@@ -69,6 +70,7 @@ from pyspark.errors.exceptions.connect import (
 from pyspark.sql.connect.expressions import (
     PythonUDF,
     CommonInlineUserDefinedFunction,
+    JavaUDF,
 )
 from pyspark.sql.connect.types import parse_data_type
 from pyspark.sql.types import (
@@ -80,6 +82,10 @@ from pyspark.serializers import CloudPickleSerializer
 from pyspark.rdd import PythonEvalType
 
 
+if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import DataTypeOrString
+
+
 def _configure_logging() -> logging.Logger:
     """Configure logging for the Spark Connect clients."""
     logger = logging.getLogger(__name__)
@@ -534,7 +540,7 @@ class SparkConnectClient(object):
     def register_udf(
         self,
         function: Any,
-        return_type: Union[str, DataType],
+        return_type: "DataTypeOrString",
         name: Optional[str] = None,
         eval_type: int = PythonEvalType.SQL_BATCHED_UDF,
         deterministic: bool = True,
@@ -561,9 +567,9 @@ class SparkConnectClient(object):
         # construct a CommonInlineUserDefinedFunction
         fun = CommonInlineUserDefinedFunction(
             function_name=name,
-            deterministic=deterministic,
             arguments=[],
             function=py_udf,
+            deterministic=deterministic,
         ).to_plan_udf(self)
 
         # construct the request
@@ -573,6 +579,35 @@ class SparkConnectClient(object):
         self._execute(req)
         return name
 
+    def register_java(
+        self,
+        name: str,
+        javaClassName: str,
+        return_type: Optional["DataTypeOrString"] = None,
+        aggregate: bool = False,
+    ) -> None:
+        # convert str return_type to DataType
+        if isinstance(return_type, str):
+            return_type = parse_data_type(return_type)
+
+        # construct a JavaUDF
+        if return_type is None:
+            java_udf = JavaUDF(class_name=javaClassName, aggregate=aggregate)
+        else:
+            java_udf = JavaUDF(
+                class_name=javaClassName,
+                output_type=return_type.json(),
+            )
+        fun = CommonInlineUserDefinedFunction(
+            function_name=name,
+            function=java_udf,
+        ).to_plan_judf(self)
+        # construct the request
+        req = self._execute_plan_request_with_metadata()
+        req.plan.command.register_function.CopyFrom(fun)
+
+        self._execute(req)
+
     def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
         return [
             PlanMetrics(
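
`register_java` wraps the `JavaUDF` in a `CommonInlineUserDefinedFunction` and
ships it as a `register_function` command. A rough sketch of the equivalent
direct calls (`_client` is internal API; class names are hypothetical):

    # normally reached via spark.udf.registerJavaFunction / registerJavaUDAF
    spark._client.register_java("strlen", "org.example.StrLen", return_type="integer")
    spark._client.register_java("myAvg", "org.example.MyDoubleAvg", aggregate=True)

Note that when `return_type` is given, the `aggregate` flag is not forwarded
and defaults to False; that is harmless here because `registerJavaUDAF` never
passes a return type.
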
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 2b1901167c1..0d059740032 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -19,6 +19,7 @@ from pyspark.sql.connect.utils import check_dependencies
 check_dependencies(__name__, __file__)
 
 from typing import (
+    cast,
     TYPE_CHECKING,
     Any,
     Union,
@@ -520,6 +521,31 @@ class PythonUDF:
         )
 
 
+class JavaUDF:
+    """Represents a Java (aggregate) user-defined function."""
+
+    def __init__(
+        self,
+        class_name: str,
+        output_type: Optional[str] = None,
+        aggregate: bool = False,
+    ) -> None:
+        self._class_name = class_name
+        self._output_type = output_type
+        self._aggregate = aggregate
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.JavaUDF:
+        expr = proto.JavaUDF()
+        expr.class_name = self._class_name
+        if self._output_type is not None:
+            expr.output_type = self._output_type
+        expr.aggregate = self._aggregate
+        return expr
+
+    def __repr__(self) -> str:
+        return f"{self._class_name}, {self._output_type}"
+
+
 class CommonInlineUserDefinedFunction(Expression):
     """Represents a user-defined function with an inlined defined function body of any programming
     languages."""
@@ -527,9 +553,9 @@ class CommonInlineUserDefinedFunction(Expression):
     def __init__(
         self,
         function_name: str,
-        deterministic: bool,
-        arguments: Sequence[Expression],
-        function: PythonUDF,
+        function: Union[PythonUDF, JavaUDF],
+        deterministic: bool = False,
+        arguments: Sequence[Expression] = [],
     ):
         self._function_name = function_name
         self._deterministic = deterministic
@@ -545,7 +571,7 @@ class CommonInlineUserDefinedFunction(Expression):
                 [arg.to_plan(session) for arg in self._arguments]
             )
         expr.common_inline_user_defined_function.python_udf.CopyFrom(
-            self._function.to_plan(session)
+            cast(proto.PythonUDF, self._function.to_plan(session))
         )
         return expr
 
@@ -557,7 +583,15 @@ class CommonInlineUserDefinedFunction(Expression):
         expr.deterministic = self._deterministic
         if len(self._arguments) > 0:
             expr.arguments.extend([arg.to_plan(session) for arg in self._arguments])
-        expr.python_udf.CopyFrom(self._function.to_plan(session))
+        expr.python_udf.CopyFrom(cast(proto.PythonUDF, self._function.to_plan(session)))
+        return expr
+
+    def to_plan_judf(
+        self, session: "SparkConnectClient"
+    ) -> "proto.CommonInlineUserDefinedFunction":
+        expr = proto.CommonInlineUserDefinedFunction()
+        expr.function_name = self._function_name
+        expr.java_udf.CopyFrom(cast(proto.JavaUDF, self._function.to_plan(session)))
         return expr
 
     def __repr__(self) -> str:
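
A small sketch of how these pieces compose (names are illustrative; `client`
stands for an active `SparkConnectClient`):

    from pyspark.sql.connect.expressions import (
        CommonInlineUserDefinedFunction,
        JavaUDF,
    )
    from pyspark.sql.types import IntegerType

    java_udf = JavaUDF(class_name="org.example.StrLen", output_type=IntegerType().json())
    fun = CommonInlineUserDefinedFunction(function_name="strlen", function=java_udf)

    # to_plan_judf populates only function_name and the java_udf oneof;
    # deterministic and arguments keep their defaults
    plan = fun.to_plan_judf(client)
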
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index d0db2ad56cc..24dd1136480 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc [...]
+    b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc [...]
 )
 
 
@@ -67,6 +67,7 @@ _COMMONINLINEUSERDEFINEDFUNCTION = DESCRIPTOR.message_types_by_name[
 ]
 _PYTHONUDF = DESCRIPTOR.message_types_by_name["PythonUDF"]
 _SCALARSCALAUDF = DESCRIPTOR.message_types_by_name["ScalarScalaUDF"]
+_JAVAUDF = DESCRIPTOR.message_types_by_name["JavaUDF"]
 _EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[
     "FrameType"
 ]
@@ -306,6 +307,17 @@ ScalarScalaUDF = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(ScalarScalaUDF)
 
+JavaUDF = _reflection.GeneratedProtocolMessageType(
+    "JavaUDF",
+    (_message.Message,),
+    {
+        "DESCRIPTOR": _JAVAUDF,
+        "__module__": "spark.connect.expressions_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.JavaUDF)
+    },
+)
+_sym_db.RegisterMessage(JavaUDF)
+
 if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
@@ -357,9 +369,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5062
     _EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5124
     _COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5140
-    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5451
-    _PYTHONUDF._serialized_start = 5454
-    _PYTHONUDF._serialized_end = 5584
-    _SCALARSCALAUDF._serialized_start = 5587
-    _SCALARSCALAUDF._serialized_end = 5771
+    _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5504
+    _PYTHONUDF._serialized_start = 5507
+    _PYTHONUDF._serialized_end = 5637
+    _SCALARSCALAUDF._serialized_start = 5640
+    _SCALARSCALAUDF._serialized_end = 5824
+    _JAVAUDF._serialized_start = 5826
+    _JAVAUDF._serialized_end = 5950
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 37db24ff91a..19b47c7ab91 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1171,10 +1171,11 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
     ARGUMENTS_FIELD_NUMBER: builtins.int
     PYTHON_UDF_FIELD_NUMBER: builtins.int
     SCALAR_SCALA_UDF_FIELD_NUMBER: builtins.int
+    JAVA_UDF_FIELD_NUMBER: builtins.int
     function_name: builtins.str
     """(Required) Name of the user-defined function."""
     deterministic: builtins.bool
-    """(Required) Indicate if the user-defined function is deterministic."""
+    """(Optional) Indicate if the user-defined function is deterministic."""
     @property
     def arguments(
         self,
@@ -1184,6 +1185,8 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
     def python_udf(self) -> global___PythonUDF: ...
     @property
     def scalar_scala_udf(self) -> global___ScalarScalaUDF: ...
+    @property
+    def java_udf(self) -> global___JavaUDF: ...
     def __init__(
         self,
         *,
@@ -1192,12 +1195,15 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
         arguments: collections.abc.Iterable[global___Expression] | None = ...,
         python_udf: global___PythonUDF | None = ...,
         scalar_scala_udf: global___ScalarScalaUDF | None = ...,
+        java_udf: global___JavaUDF | None = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
             "function",
             b"function",
+            "java_udf",
+            b"java_udf",
             "python_udf",
             b"python_udf",
             "scalar_scala_udf",
@@ -1215,6 +1221,8 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
             b"function",
             "function_name",
             b"function_name",
+            "java_udf",
+            b"java_udf",
             "python_udf",
             b"python_udf",
             "scalar_scala_udf",
@@ -1223,7 +1231,7 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
     ) -> None: ...
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["function", b"function"]
-    ) -> typing_extensions.Literal["python_udf", "scalar_scala_udf"] | None: ...
+    ) -> typing_extensions.Literal["python_udf", "scalar_scala_udf", "java_udf"] | None: ...
 
 global___CommonInlineUserDefinedFunction = CommonInlineUserDefinedFunction
 
@@ -1314,3 +1322,47 @@ class ScalarScalaUDF(google.protobuf.message.Message):
     ) -> None: ...
 
 global___ScalarScalaUDF = ScalarScalaUDF
+
+class JavaUDF(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    CLASS_NAME_FIELD_NUMBER: builtins.int
+    OUTPUT_TYPE_FIELD_NUMBER: builtins.int
+    AGGREGATE_FIELD_NUMBER: builtins.int
+    class_name: builtins.str
+    """(Required) Fully qualified name of Java class"""
+    output_type: builtins.str
+    """(Optional) Output type of the Java UDF"""
+    aggregate: builtins.bool
+    """(Required) Indicate if the Java user-defined function is an aggregate function"""
+    def __init__(
+        self,
+        *,
+        class_name: builtins.str = ...,
+        output_type: builtins.str | None = ...,
+        aggregate: builtins.bool = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_output_type", b"_output_type", "output_type", b"output_type"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_output_type",
+            b"_output_type",
+            "aggregate",
+            b"aggregate",
+            "class_name",
+            b"class_name",
+            "output_type",
+            b"output_type",
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_output_type", b"_output_type"]
+    ) -> typing_extensions.Literal["output_type"] | None: ...
+
+global___JavaUDF = JavaUDF
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index c6bff4a3caa..03e53cbd89e 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -127,9 +127,9 @@ class UserDefinedFunction:
         )
         return CommonInlineUserDefinedFunction(
             function_name=self._name,
+            function=py_udf,
             deterministic=self.deterministic,
             arguments=arg_exprs,
-            function=py_udf,
         )
 
     def __call__(self, *cols: "ColumnOrName") -> Column:
@@ -232,3 +232,18 @@ class UDFRegistration:
         return return_udf
 
     register.__doc__ = PySparkUDFRegistration.register.__doc__
+
+    def registerJavaFunction(
+        self,
+        name: str,
+        javaClassName: str,
+        returnType: Optional["DataTypeOrString"] = None,
+    ) -> None:
+        self.sparkSession._client.register_java(name, javaClassName, returnType)
+
+    registerJavaFunction.__doc__ = PySparkUDFRegistration.registerJavaFunction.__doc__
+
+    def registerJavaUDAF(self, name: str, javaClassName: str) -> None:
+        self.sparkSession._client.register_java(name, javaClassName, aggregate=True)
+
+    registerJavaUDAF.__doc__ = PySparkUDFRegistration.registerJavaUDAF.__doc__
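
End to end, the new registration methods mirror vanilla PySpark; a sketch
assuming the Java classes (the fixture classes used in the PySpark docs) are on
the connect server's classpath:

    from pyspark.sql.types import IntegerType

    spark.udf.registerJavaFunction(
        "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType()
    )
    spark.sql("SELECT javaStringLength('test')").show()

    spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
    spark.sql("SELECT javaUDAF(id) FROM range(10)").show()
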
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py
index 293f4b0f41a..b38b4c28a25 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -25,6 +25,7 @@ if should_test_connect:
 
     sql.udf.UserDefinedFunction = UserDefinedFunction
 
+from pyspark.errors import AnalysisException
 from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
 from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.sql.types import IntegerType
@@ -103,30 +104,13 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
     def test_udf_registration_return_type_none(self):
         super().test_udf_registration_return_type_none()
 
-    # TODO(SPARK-42210): implement `spark.udf`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_non_existed_udaf(self):
-        super().test_non_existed_udaf()
-
-    # TODO(SPARK-42210): implement `spark.udf`
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_non_existed_udf(self):
-        super().test_non_existed_udf()
-
-    # TODO(SPARK-42210): implement `spark.udf`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_register_java_function(self):
-        super().test_register_java_function()
-
-    # TODO(SPARK-42210): implement `spark.udf`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_register_java_udaf(self):
-        super().test_register_java_udaf()
-
-    # TODO(SPARK-42210): implement `spark.udf`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udf_in_left_outer_join_condition(self):
-        super().test_udf_in_left_outer_join_condition()
+        spark = self.spark
+        self.assertRaisesRegex(
+            AnalysisException,
+            "Can not load class non_existed_udf",
+            lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"),
+        )
 
     def test_udf_registration_returns_udf(self):
         df = self.spark.range(10)
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 9f8e3e46977..0b9b082ade3 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -622,6 +622,9 @@ class UDFRegistration:
 
         .. versionadded:: 2.3.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         name : str
@@ -666,6 +669,9 @@ class UDFRegistration:
 
         .. versionadded:: 2.3.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         name : str
             name of the user-defined aggregate function
         javaClassName : str

