Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/13 05:10:55 UTC
[spark] branch branch-3.4 updated: [SPARK-42269][CONNECT][PYTHON] Support complex return types in DDL strings
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 95919e26993 [SPARK-42269][CONNECT][PYTHON] Support complex return types in DDL strings
95919e26993 is described below
commit 95919e269930f3d1f3716b869e1abb25185d8a44
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Mon Feb 13 14:10:34 2023 +0900
[SPARK-42269][CONNECT][PYTHON] Support complex return types in DDL strings
### What changes were proposed in this pull request?
Support complex return types, such as multi-field structs, given as DDL strings for Python UDFs in Spark Connect.
### Why are the changes needed?
Parity with vanilla PySpark.
### Does this PR introduce _any_ user-facing change?
Yes.
```py
# BEFORE
>>> spark.range(2).select(udf(lambda x: (x, x), "struct<x:integer, y:integer>")("id"))
...
AssertionError: returnType should be singular
>>> spark.udf.register('f', lambda x: (x, x), "struct<x:integer, y:integer>")
...
AssertionError: returnType should be singular
# AFTER
>>> spark.range(2).select(udf(lambda x: (x, x), "struct<x:integer, y:integer>")("id"))
DataFrame[<lambda>(id): struct<x:int,y:int>]
>>> spark.udf.register('f', lambda x: (x, x), "struct<x:integer, y:integer>")
<function <lambda> at 0x7faee0eaaca0>
```
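As a follow-up to the AFTER snippet, here is a sketch of querying the registered UDF through SQL. The exact `collect()` output is illustrative (an assumption based on the declared struct type), not taken from this patch's tests:
```py
# Illustrative continuation; assumes an active Spark Connect session bound
# to `spark`. The struct DDL string now parses into a proper StructType,
# so each UDF result comes back as a nested Row.
>>> spark.udf.register("f", lambda x: (x, x), "struct<x:integer, y:integer>")
<function <lambda> at 0x7faee0eaaca0>
>>> spark.sql("SELECT f(id) AS s FROM range(2)").collect()
[Row(s=Row(x=0, y=0)), Row(s=Row(x=1, y=1))]
```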
### How was this patch tested?
Unit tests; the previously skipped parity test `test_udf_with_string_return_type` is re-enabled.
Closes #39964 from xinrong-meng/collection_ret_type.
Authored-by: Xinrong Meng <xi...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
(cherry picked from commit 3985b91633f5e49c8c97433651f81604dad193e9)
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/sql/connect/client.py | 15 ++------------
python/pyspark/sql/connect/types.py | 19 ++++++++++++++++++
python/pyspark/sql/connect/udf.py | 23 ++++------------------
.../pyspark/sql/tests/connect/test_parity_udf.py | 5 -----
4 files changed, 25 insertions(+), 37 deletions(-)
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 943a7e70464..2c07596fec0 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -68,12 +68,12 @@ from pyspark.sql.connect.expressions import (
     PythonUDF,
     CommonInlineUserDefinedFunction,
 )
+from pyspark.sql.connect.types import parse_data_type
 from pyspark.sql.types import (
     DataType,
     StructType,
     StructField,
 )
-from pyspark.sql.utils import is_remote
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.rdd import PythonEvalType
@@ -443,23 +443,12 @@ class SparkConnectClient(object):
         """Create a temporary UDF in the session catalog on the other side. We generate a
         temporary name for it."""
-        from pyspark.sql import SparkSession as PySparkSession
-
         if name is None:
             name = f"fun_{uuid.uuid4().hex}"
 
         # convert str return_type to DataType
         if isinstance(return_type, str):
-
-            assert is_remote()
-            return_type_schema = (  # a workaround to parse the DataType from DDL strings
-                PySparkSession.builder.getOrCreate()
-                .createDataFrame(data=[], schema=return_type)
-                .schema
-            )
-            assert len(return_type_schema.fields) == 1, "returnType should be singular"
-            return_type = return_type_schema.fields[0].dataType
-
+            return_type = parse_data_type(return_type)
 
         # construct a PythonUDF
         py_udf = PythonUDF(
             output_type=return_type.json(),
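For reference, the `return_type.json()` call above serializes the parsed type into the JSON form carried inside the `PythonUDF` message. A quick illustration of that serialization (the REPL output shown is what `DataType.json()` is expected to produce, included for context rather than taken from this patch):
```py
# DataType.json() emits a compact, key-sorted JSON encoding of the type.
>>> from pyspark.sql.types import StructType, StructField, IntegerType
>>> StructType([StructField("x", IntegerType()), StructField("y", IntegerType())]).json()
'{"fields":[{"metadata":{},"name":"x","nullable":true,"type":"integer"},{"metadata":{},"name":"y","nullable":true,"type":"integer"}],"type":"struct"}'
```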
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index f12d6c4827e..6b9975c52cd 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -51,6 +51,7 @@ from pyspark.sql.types import (
 )
 
 import pyspark.sql.connect.proto as pb2
+from pyspark.sql.utils import is_remote
 
 
 JVM_BYTE_MIN: int = -(1 << 7)
@@ -337,3 +338,21 @@ def from_arrow_schema(arrow_schema: "pa.Schema") -> StructType:
             for field in arrow_schema
         ]
     )
+
+
+def parse_data_type(data_type: str) -> DataType:
+    # Currently we don't have a way to get the current Spark session in Spark Connect, and
+    # pyspark.sql.SparkSession has centralized logic to control session creation.
+    # So we use pyspark.sql.SparkSession for now; this should be replaced with the current
+    # Spark session for Spark Connect in the future.
+    from pyspark.sql import SparkSession as PySparkSession
+
+    assert is_remote()
+    return_type_schema = (
+        PySparkSession.builder.getOrCreate().createDataFrame(data=[], schema=data_type).schema
+    )
+    if len(return_type_schema.fields) == 1:
+        return_type = return_type_schema.fields[0].dataType
+    else:
+        return_type = return_type_schema
+    return return_type
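To make the branching in `parse_data_type` concrete: a single-field DDL string yields that field's `DataType`, while a multi-field DDL string yields the whole `StructType`. A hedged REPL sketch (outputs are illustrative and assume a remote session so the `is_remote()` assertion passes):
```py
>>> from pyspark.sql.connect.types import parse_data_type
>>> parse_data_type("array<integer>")  # one field -> that field's dataType
ArrayType(IntegerType(), True)
>>> parse_data_type("struct<x:integer, y:integer>")  # several fields -> the full StructType
StructType([StructField('x', IntegerType(), True), StructField('y', IntegerType(), True)])
```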
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 573d8f582e2..39c31e85992 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -33,9 +33,9 @@ from pyspark.sql.connect.expressions import (
     CommonInlineUserDefinedFunction,
 )
 from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.types import parse_data_type
 from pyspark.sql.types import DataType, StringType
 from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration
-from pyspark.sql.utils import is_remote
 
 
 if TYPE_CHECKING:
@@ -99,24 +99,9 @@ class UserDefinedFunction:
             )
 
         self.func = func
-
-        if isinstance(returnType, str):
-            # Currently we don't have a way to have a current Spark session in Spark Connect, and
-            # pyspark.sql.SparkSession has a centralized logic to control the session creation.
-            # So uses pyspark.sql.SparkSession for now. Should replace this to using the current
-            # Spark session for Spark Connect in the future.
-            from pyspark.sql import SparkSession as PySparkSession
-
-            assert is_remote()
-            return_type_schema = (  # a workaround to parse the DataType from DDL strings
-                PySparkSession.builder.getOrCreate()
-                .createDataFrame(data=[], schema=returnType)
-                .schema
-            )
-            assert len(return_type_schema.fields) == 1, "returnType should be singular"
-            self._returnType = return_type_schema.fields[0].dataType
-        else:
-            self._returnType = returnType
+        self._returnType = (
+            parse_data_type(returnType) if isinstance(returnType, str) else returnType
+        )
         self._name = name or (
             func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
         )
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py
index b35f55febf2..160f06d37f7 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -165,11 +165,6 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
     def test_udf_in_left_outer_join_condition(self):
         super().test_udf_in_left_outer_join_condition()
 
-    # TODO(SPARK-42269): support return type as a collection DataType in DDL strings
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udf_with_string_return_type(self):
-        super().test_udf_with_string_return_type()
-
     def test_udf_registration_returns_udf(self):
         df = self.spark.range(10)
         add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
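Removing the skip re-enables the inherited `test_udf_with_string_return_type` from `BaseUDFTestsMixin`. For orientation only, a sketch of the shape such a parity check takes; the method body, names, and assertions below are illustrative, not the actual test:
```py
# Hypothetical sketch, not the real BaseUDFTestsMixin body; assumes
# `self.spark` is the Spark Connect session used by the class above.
from pyspark.sql import Row
from pyspark.sql.functions import udf

def test_udf_with_string_return_type_sketch(self):
    add_one = udf(lambda x: x + 1, "integer")
    make_pair = udf(lambda x: (x, x), "struct<x:integer,y:integer>")
    row = self.spark.range(1).select(
        add_one("id").alias("a"), make_pair("id").alias("p")
    ).first()
    self.assertEqual(row, Row(a=1, p=Row(x=0, y=0)))
```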