You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by xi...@apache.org on 2023/03/08 06:23:38 UTC
[spark] branch master updated: [SPARK-42643][CONNECT][PYTHON] Register Java (aggregate) user-defined functions
This is an automated email from the ASF dual-hosted git repository.
xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 92aa08786fe [SPARK-42643][CONNECT][PYTHON] Register Java (aggregate) user-defined functions
92aa08786fe is described below
commit 92aa08786feaf473330a863d19b0c902b721789e
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Wed Mar 8 14:23:18 2023 +0800
[SPARK-42643][CONNECT][PYTHON] Register Java (aggregate) user-defined functions
### What changes were proposed in this pull request?
Implement `spark.udf.registerJavaFunction` and `spark.udf.registerJavaUDAF`.
A new proto `JavaUDF` is introduced.
### Why are the changes needed?
Parity with vanilla PySpark.
### Does this PR introduce _any_ user-facing change?
Yes. `spark.udf.registerJavaFunction` and `spark.udf.registerJavaUDAF` are now supported.
### How was this patch tested?
Parity unit tests.
Closes #40244 from xinrong-meng/registerJava.
Authored-by: Xinrong Meng <xi...@apache.org>
Signed-off-by: Xinrong Meng <xi...@apache.org>
---
.../main/protobuf/spark/connect/expressions.proto | 13 ++++-
.../sql/connect/planner/SparkConnectPlanner.scala | 21 ++++++++
python/pyspark/sql/connect/client.py | 39 ++++++++++++++-
python/pyspark/sql/connect/expressions.py | 44 +++++++++++++++--
.../pyspark/sql/connect/proto/expressions_pb2.py | 26 +++++++---
.../pyspark/sql/connect/proto/expressions_pb2.pyi | 56 +++++++++++++++++++++-
python/pyspark/sql/connect/udf.py | 17 ++++++-
.../pyspark/sql/tests/connect/test_parity_udf.py | 30 +++---------
python/pyspark/sql/udf.py | 6 +++
9 files changed, 212 insertions(+), 40 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
index 6eb769ad27e..0aee3ca13b9 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/expressions.proto
@@ -312,7 +312,7 @@ message Expression {
message CommonInlineUserDefinedFunction {
// (Required) Name of the user-defined function.
string function_name = 1;
- // (Required) Indicate if the user-defined function is deterministic.
+ // (Optional) Indicate if the user-defined function is deterministic.
bool deterministic = 2;
// (Optional) Function arguments. Empty arguments are allowed.
repeated Expression arguments = 3;
@@ -320,6 +320,7 @@ message CommonInlineUserDefinedFunction {
oneof function {
PythonUDF python_udf = 4;
ScalarScalaUDF scalar_scala_udf = 5;
+ JavaUDF java_udf = 6;
}
}
@@ -345,3 +346,13 @@ message ScalarScalaUDF {
bool nullable = 4;
}
+message JavaUDF {
+ // (Required) Fully qualified name of Java class
+ string class_name = 1;
+
+ // (Optional) Output type of the Java UDF
+ optional string output_type = 2;
+
+ // (Required) Indicate if the Java user-defined function is an aggregate function
+ bool aggregate = 3;
+}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d7b3c057d92..3b9443f4e3c 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1552,6 +1552,8 @@ class SparkConnectPlanner(val session: SparkSession) {
fun.getFunctionCase match {
case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
handleRegisterPythonUDF(fun)
+ case proto.CommonInlineUserDefinedFunction.FunctionCase.JAVA_UDF =>
+ handleRegisterJavaUDF(fun)
case _ =>
throw InvalidPlanInput(
s"Function with ID: ${fun.getFunctionCase.getNumber} is not supported")
@@ -1577,6 +1579,25 @@ class SparkConnectPlanner(val session: SparkSession) {
session.udf.registerPython(fun.getFunctionName, udpf)
}
+ private def handleRegisterJavaUDF(fun: proto.CommonInlineUserDefinedFunction): Unit = {
+ val udf = fun.getJavaUdf
+ val dataType =
+ if (udf.hasOutputType) {
+ DataType.parseTypeWithFallback(
+ schema = udf.getOutputType,
+ parser = DataType.fromDDL,
+ fallbackParser = DataType.fromJson) match {
+ case s: DataType => s
+ case other => throw InvalidPlanInput(s"Invalid return type $other")
+ }
+ } else null
+ if (udf.getAggregate) {
+ session.udf.registerJavaUDAF(fun.getFunctionName, udf.getClassName)
+ } else {
+ session.udf.registerJava(fun.getFunctionName, udf.getClassName, dataType)
+ }
+ }
+
private def handleCommandPlugin(extension: ProtoAny): Unit = {
SparkConnectPluginRegistry.commandRegistry
// Lazily traverse the collection.
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 8c85f17bb5f..6334036fca4 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -47,6 +47,7 @@ from typing import (
Callable,
Generator,
Type,
+ TYPE_CHECKING,
)
import pandas as pd
@@ -69,6 +70,7 @@ from pyspark.errors.exceptions.connect import (
from pyspark.sql.connect.expressions import (
PythonUDF,
CommonInlineUserDefinedFunction,
+ JavaUDF,
)
from pyspark.sql.connect.types import parse_data_type
from pyspark.sql.types import (
@@ -80,6 +82,10 @@ from pyspark.serializers import CloudPickleSerializer
from pyspark.rdd import PythonEvalType
+if TYPE_CHECKING:
+ from pyspark.sql.connect._typing import DataTypeOrString
+
+
def _configure_logging() -> logging.Logger:
"""Configure logging for the Spark Connect clients."""
logger = logging.getLogger(__name__)
@@ -534,7 +540,7 @@ class SparkConnectClient(object):
def register_udf(
self,
function: Any,
- return_type: Union[str, DataType],
+ return_type: "DataTypeOrString",
name: Optional[str] = None,
eval_type: int = PythonEvalType.SQL_BATCHED_UDF,
deterministic: bool = True,
@@ -561,9 +567,9 @@ class SparkConnectClient(object):
# construct a CommonInlineUserDefinedFunction
fun = CommonInlineUserDefinedFunction(
function_name=name,
- deterministic=deterministic,
arguments=[],
function=py_udf,
+ deterministic=deterministic,
).to_plan_udf(self)
# construct the request
@@ -573,6 +579,35 @@ class SparkConnectClient(object):
self._execute(req)
return name
+ def register_java(
+ self,
+ name: str,
+ javaClassName: str,
+ return_type: Optional["DataTypeOrString"] = None,
+ aggregate: bool = False,
+ ) -> None:
+ # convert str return_type to DataType
+ if isinstance(return_type, str):
+ return_type = parse_data_type(return_type)
+
+ # construct a JavaUDF
+ if return_type is None:
+ java_udf = JavaUDF(class_name=javaClassName, aggregate=aggregate)
+ else:
+ java_udf = JavaUDF(
+ class_name=javaClassName,
+ output_type=return_type.json(),
+ )
+ fun = CommonInlineUserDefinedFunction(
+ function_name=name,
+ function=java_udf,
+ ).to_plan_judf(self)
+ # construct the request
+ req = self._execute_plan_request_with_metadata()
+ req.plan.command.register_function.CopyFrom(fun)
+
+ self._execute(req)
+
def _build_metrics(self, metrics: "pb2.ExecutePlanResponse.Metrics") -> List[PlanMetrics]:
return [
PlanMetrics(
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 2b1901167c1..0d059740032 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -19,6 +19,7 @@ from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__, __file__)
from typing import (
+ cast,
TYPE_CHECKING,
Any,
Union,
@@ -520,6 +521,31 @@ class PythonUDF:
)
+class JavaUDF:
+ """Represents a Java (aggregate) user-defined function."""
+
+ def __init__(
+ self,
+ class_name: str,
+ output_type: Optional[str] = None,
+ aggregate: bool = False,
+ ) -> None:
+ self._class_name = class_name
+ self._output_type = output_type
+ self._aggregate = aggregate
+
+ def to_plan(self, session: "SparkConnectClient") -> proto.JavaUDF:
+ expr = proto.JavaUDF()
+ expr.class_name = self._class_name
+ if self._output_type is not None:
+ expr.output_type = self._output_type
+ expr.aggregate = self._aggregate
+ return expr
+
+ def __repr__(self) -> str:
+ return f"{self._class_name}, {self._output_type}"
+
+
class CommonInlineUserDefinedFunction(Expression):
"""Represents a user-defined function with an inlined defined function body of any programming
languages."""
@@ -527,9 +553,9 @@ class CommonInlineUserDefinedFunction(Expression):
def __init__(
self,
function_name: str,
- deterministic: bool,
- arguments: Sequence[Expression],
- function: PythonUDF,
+ function: Union[PythonUDF, JavaUDF],
+ deterministic: bool = False,
+ arguments: Sequence[Expression] = [],
):
self._function_name = function_name
self._deterministic = deterministic
@@ -545,7 +571,7 @@ class CommonInlineUserDefinedFunction(Expression):
[arg.to_plan(session) for arg in self._arguments]
)
expr.common_inline_user_defined_function.python_udf.CopyFrom(
- self._function.to_plan(session)
+ cast(proto.PythonUDF, self._function.to_plan(session))
)
return expr
@@ -557,7 +583,15 @@ class CommonInlineUserDefinedFunction(Expression):
expr.deterministic = self._deterministic
if len(self._arguments) > 0:
expr.arguments.extend([arg.to_plan(session) for arg in self._arguments])
- expr.python_udf.CopyFrom(self._function.to_plan(session))
+ expr.python_udf.CopyFrom(cast(proto.PythonUDF, self._function.to_plan(session)))
+ return expr
+
+ def to_plan_judf(
+ self, session: "SparkConnectClient"
+ ) -> "proto.CommonInlineUserDefinedFunction":
+ expr = proto.CommonInlineUserDefinedFunction()
+ expr.function_name = self._function_name
+ expr.java_udf.CopyFrom(cast(proto.JavaUDF, self._function.to_plan(session)))
return expr
def __repr__(self) -> str:
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index d0db2ad56cc..24dd1136480 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -34,7 +34,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc [...]
+ b'\n\x1fspark/connect/expressions.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x19spark/connect/types.proto"\xa8\'\n\nExpression\x12=\n\x07literal\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x07literal\x12\x62\n\x14unresolved_attribute\x18\x02 \x01(\x0b\x32-.spark.connect.Expression.UnresolvedAttributeH\x00R\x13unresolvedAttribute\x12_\n\x13unresolved_function\x18\x03 \x01(\x0b\x32,.spark.connect.Expression.UnresolvedFunctionH\x00R\x12unresolvedFunc [...]
)
@@ -67,6 +67,7 @@ _COMMONINLINEUSERDEFINEDFUNCTION = DESCRIPTOR.message_types_by_name[
]
_PYTHONUDF = DESCRIPTOR.message_types_by_name["PythonUDF"]
_SCALARSCALAUDF = DESCRIPTOR.message_types_by_name["ScalarScalaUDF"]
+_JAVAUDF = DESCRIPTOR.message_types_by_name["JavaUDF"]
_EXPRESSION_WINDOW_WINDOWFRAME_FRAMETYPE = _EXPRESSION_WINDOW_WINDOWFRAME.enum_types_by_name[
"FrameType"
]
@@ -306,6 +307,17 @@ ScalarScalaUDF = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(ScalarScalaUDF)
+JavaUDF = _reflection.GeneratedProtocolMessageType(
+ "JavaUDF",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _JAVAUDF,
+ "__module__": "spark.connect.expressions_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.JavaUDF)
+ },
+)
+_sym_db.RegisterMessage(JavaUDF)
+
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
@@ -357,9 +369,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_start = 5062
_EXPRESSION_UNRESOLVEDNAMEDLAMBDAVARIABLE._serialized_end = 5124
_COMMONINLINEUSERDEFINEDFUNCTION._serialized_start = 5140
- _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5451
- _PYTHONUDF._serialized_start = 5454
- _PYTHONUDF._serialized_end = 5584
- _SCALARSCALAUDF._serialized_start = 5587
- _SCALARSCALAUDF._serialized_end = 5771
+ _COMMONINLINEUSERDEFINEDFUNCTION._serialized_end = 5504
+ _PYTHONUDF._serialized_start = 5507
+ _PYTHONUDF._serialized_end = 5637
+ _SCALARSCALAUDF._serialized_start = 5640
+ _SCALARSCALAUDF._serialized_end = 5824
+ _JAVAUDF._serialized_start = 5826
+ _JAVAUDF._serialized_end = 5950
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.pyi b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
index 37db24ff91a..19b47c7ab91 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1171,10 +1171,11 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
ARGUMENTS_FIELD_NUMBER: builtins.int
PYTHON_UDF_FIELD_NUMBER: builtins.int
SCALAR_SCALA_UDF_FIELD_NUMBER: builtins.int
+ JAVA_UDF_FIELD_NUMBER: builtins.int
function_name: builtins.str
"""(Required) Name of the user-defined function."""
deterministic: builtins.bool
- """(Required) Indicate if the user-defined function is deterministic."""
+ """(Optional) Indicate if the user-defined function is deterministic."""
@property
def arguments(
self,
@@ -1184,6 +1185,8 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
def python_udf(self) -> global___PythonUDF: ...
@property
def scalar_scala_udf(self) -> global___ScalarScalaUDF: ...
+ @property
+ def java_udf(self) -> global___JavaUDF: ...
def __init__(
self,
*,
@@ -1192,12 +1195,15 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
arguments: collections.abc.Iterable[global___Expression] | None = ...,
python_udf: global___PythonUDF | None = ...,
scalar_scala_udf: global___ScalarScalaUDF | None = ...,
+ java_udf: global___JavaUDF | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
"function",
b"function",
+ "java_udf",
+ b"java_udf",
"python_udf",
b"python_udf",
"scalar_scala_udf",
@@ -1215,6 +1221,8 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
b"function",
"function_name",
b"function_name",
+ "java_udf",
+ b"java_udf",
"python_udf",
b"python_udf",
"scalar_scala_udf",
@@ -1223,7 +1231,7 @@ class CommonInlineUserDefinedFunction(google.protobuf.message.Message):
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["function", b"function"]
- ) -> typing_extensions.Literal["python_udf", "scalar_scala_udf"] | None: ...
+ ) -> typing_extensions.Literal["python_udf", "scalar_scala_udf", "java_udf"] | None: ...
global___CommonInlineUserDefinedFunction = CommonInlineUserDefinedFunction
@@ -1314,3 +1322,47 @@ class ScalarScalaUDF(google.protobuf.message.Message):
) -> None: ...
global___ScalarScalaUDF = ScalarScalaUDF
+
+class JavaUDF(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ CLASS_NAME_FIELD_NUMBER: builtins.int
+ OUTPUT_TYPE_FIELD_NUMBER: builtins.int
+ AGGREGATE_FIELD_NUMBER: builtins.int
+ class_name: builtins.str
+ """(Required) Fully qualified name of Java class"""
+ output_type: builtins.str
+ """(Optional) Output type of the Java UDF"""
+ aggregate: builtins.bool
+ """(Required) Indicate if the Java user-defined function is an aggregate function"""
+ def __init__(
+ self,
+ *,
+ class_name: builtins.str = ...,
+ output_type: builtins.str | None = ...,
+ aggregate: builtins.bool = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_output_type", b"_output_type", "output_type", b"output_type"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_output_type",
+ b"_output_type",
+ "aggregate",
+ b"aggregate",
+ "class_name",
+ b"class_name",
+ "output_type",
+ b"output_type",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_output_type", b"_output_type"]
+ ) -> typing_extensions.Literal["output_type"] | None: ...
+
+global___JavaUDF = JavaUDF
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index c6bff4a3caa..03e53cbd89e 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -127,9 +127,9 @@ class UserDefinedFunction:
)
return CommonInlineUserDefinedFunction(
function_name=self._name,
+ function=py_udf,
deterministic=self.deterministic,
arguments=arg_exprs,
- function=py_udf,
)
def __call__(self, *cols: "ColumnOrName") -> Column:
@@ -232,3 +232,18 @@ class UDFRegistration:
return return_udf
register.__doc__ = PySparkUDFRegistration.register.__doc__
+
+ def registerJavaFunction(
+ self,
+ name: str,
+ javaClassName: str,
+ returnType: Optional["DataTypeOrString"] = None,
+ ) -> None:
+ self.sparkSession._client.register_java(name, javaClassName, returnType)
+
+ registerJavaFunction.__doc__ = PySparkUDFRegistration.registerJavaFunction.__doc__
+
+ def registerJavaUDAF(self, name: str, javaClassName: str) -> None:
+ self.sparkSession._client.register_java(name, javaClassName, aggregate=True)
+
+ registerJavaUDAF.__doc__ = PySparkUDFRegistration.registerJavaUDAF.__doc__
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py
index 293f4b0f41a..b38b4c28a25 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -25,6 +25,7 @@ if should_test_connect:
sql.udf.UserDefinedFunction = UserDefinedFunction
+from pyspark.errors import AnalysisException
from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase
from pyspark.sql.types import IntegerType
@@ -103,30 +104,13 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
def test_udf_registration_return_type_none(self):
super().test_udf_registration_return_type_none()
- # TODO(SPARK-42210): implement `spark.udf`
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_non_existed_udaf(self):
- super().test_non_existed_udaf()
-
- # TODO(SPARK-42210): implement `spark.udf`
- @unittest.skip("Fails in Spark Connect, should enable.")
def test_non_existed_udf(self):
- super().test_non_existed_udf()
-
- # TODO(SPARK-42210): implement `spark.udf`
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_register_java_function(self):
- super().test_register_java_function()
-
- # TODO(SPARK-42210): implement `spark.udf`
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_register_java_udaf(self):
- super().test_register_java_udaf()
-
- # TODO(SPARK-42210): implement `spark.udf`
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_udf_in_left_outer_join_condition(self):
- super().test_udf_in_left_outer_join_condition()
+ spark = self.spark
+ self.assertRaisesRegex(
+ AnalysisException,
+ "Can not load class non_existed_udf",
+ lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"),
+ )
def test_udf_registration_returns_udf(self):
df = self.spark.range(10)
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 9f8e3e46977..0b9b082ade3 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -622,6 +622,9 @@ class UDFRegistration:
.. versionadded:: 2.3.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Parameters
----------
name : str
@@ -666,6 +669,9 @@ class UDFRegistration:
.. versionadded:: 2.3.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
name : str
name of the user-defined aggregate function
javaClassName : str
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org