You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text version).
Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/13 05:10:47 UTC

[spark] branch master updated: [SPARK-42269][CONNECT][PYTHON] Support complex return types in DDL strings

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3985b91633f [SPARK-42269][CONNECT][PYTHON] Support complex return types in DDL strings
3985b91633f is described below

commit 3985b91633f5e49c8c97433651f81604dad193e9
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Mon Feb 13 14:10:34 2023 +0900

    [SPARK-42269][CONNECT][PYTHON] Support complex return types in DDL strings
    
    ### What changes were proposed in this pull request?
    Support complex return types (e.g. structs) given as DDL strings for Spark Connect Python UDFs.
    
    ### Why are the changes needed?
    Parity with vanilla PySpark, where such DDL return-type strings are already supported.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    
    ```py
    # BEFORE
    >>> spark.range(2).select(udf(lambda x: (x, x), "struct<x:integer, y:integer>")("id"))
    ...
    AssertionError: returnType should be singular
    
    >>> spark.udf.register('f', lambda x: (x, x), "struct<x:integer, y:integer>")
    ...
    AssertionError: returnType should be singular
    
    # AFTER
    >>> spark.range(2).select(udf(lambda x: (x, x), "struct<x:integer, y:integer>")("id"))
    DataFrame[<lambda>(id): struct<x:int,y:int>]
    
    >>> spark.udf.register('f', lambda x: (x, x), "struct<x:integer, y:integer>")
    <function <lambda> at 0x7faee0eaaca0>
    
    ```
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #39964 from xinrong-meng/collection_ret_type.
    
    Authored-by: Xinrong Meng <xi...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/connect/client.py               | 15 ++------------
 python/pyspark/sql/connect/types.py                | 19 ++++++++++++++++++
 python/pyspark/sql/connect/udf.py                  | 23 ++++------------------
 .../pyspark/sql/tests/connect/test_parity_udf.py   |  5 -----
 4 files changed, 25 insertions(+), 37 deletions(-)

diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 943a7e70464..2c07596fec0 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -68,12 +68,12 @@ from pyspark.sql.connect.expressions import (
     PythonUDF,
     CommonInlineUserDefinedFunction,
 )
+from pyspark.sql.connect.types import parse_data_type
 from pyspark.sql.types import (
     DataType,
     StructType,
     StructField,
 )
-from pyspark.sql.utils import is_remote
 from pyspark.serializers import CloudPickleSerializer
 from pyspark.rdd import PythonEvalType
 
@@ -443,23 +443,12 @@ class SparkConnectClient(object):
         """Create a temporary UDF in the session catalog on the other side. We generate a
         temporary name for it."""
 
-        from pyspark.sql import SparkSession as PySparkSession
-
         if name is None:
             name = f"fun_{uuid.uuid4().hex}"
 
         # convert str return_type to DataType
         if isinstance(return_type, str):
-
-            assert is_remote()
-            return_type_schema = (  # a workaround to parse the DataType from DDL strings
-                PySparkSession.builder.getOrCreate()
-                .createDataFrame(data=[], schema=return_type)
-                .schema
-            )
-            assert len(return_type_schema.fields) == 1, "returnType should be singular"
-            return_type = return_type_schema.fields[0].dataType
-
+            return_type = parse_data_type(return_type)
         # construct a PythonUDF
         py_udf = PythonUDF(
             output_type=return_type.json(),
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index f12d6c4827e..6b9975c52cd 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -51,6 +51,7 @@ from pyspark.sql.types import (
 )
 
 import pyspark.sql.connect.proto as pb2
+from pyspark.sql.utils import is_remote
 
 
 JVM_BYTE_MIN: int = -(1 << 7)
@@ -337,3 +338,21 @@ def from_arrow_schema(arrow_schema: "pa.Schema") -> StructType:
             for field in arrow_schema
         ]
     )
+
+
+def parse_data_type(data_type: str) -> DataType:
+    # Currently we don't have a way to have a current Spark session in Spark Connect, and
+    # pyspark.sql.SparkSession has a centralized logic to control the session creation.
+    # So uses pyspark.sql.SparkSession for now. Should replace this to using the current
+    # Spark session for Spark Connect in the future.
+    from pyspark.sql import SparkSession as PySparkSession
+
+    assert is_remote()
+    return_type_schema = (
+        PySparkSession.builder.getOrCreate().createDataFrame(data=[], schema=data_type).schema
+    )
+    if len(return_type_schema.fields) == 1:
+        return_type = return_type_schema.fields[0].dataType
+    else:
+        return_type = return_type_schema
+    return return_type
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 573d8f582e2..39c31e85992 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -33,9 +33,9 @@ from pyspark.sql.connect.expressions import (
     CommonInlineUserDefinedFunction,
 )
 from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.types import parse_data_type
 from pyspark.sql.types import DataType, StringType
 from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration
-from pyspark.sql.utils import is_remote
 
 
 if TYPE_CHECKING:
@@ -99,24 +99,9 @@ class UserDefinedFunction:
             )
 
         self.func = func
-
-        if isinstance(returnType, str):
-            # Currently we don't have a way to have a current Spark session in Spark Connect, and
-            # pyspark.sql.SparkSession has a centralized logic to control the session creation.
-            # So uses pyspark.sql.SparkSession for now. Should replace this to using the current
-            # Spark session for Spark Connect in the future.
-            from pyspark.sql import SparkSession as PySparkSession
-
-            assert is_remote()
-            return_type_schema = (  # a workaround to parse the DataType from DDL strings
-                PySparkSession.builder.getOrCreate()
-                .createDataFrame(data=[], schema=returnType)
-                .schema
-            )
-            assert len(return_type_schema.fields) == 1, "returnType should be singular"
-            self._returnType = return_type_schema.fields[0].dataType
-        else:
-            self._returnType = returnType
+        self._returnType = (
+            parse_data_type(returnType) if isinstance(returnType, str) else returnType
+        )
         self._name = name or (
             func.__name__ if hasattr(func, "__name__") else func.__class__.__name__
         )
diff --git a/python/pyspark/sql/tests/connect/test_parity_udf.py b/python/pyspark/sql/tests/connect/test_parity_udf.py
index b35f55febf2..160f06d37f7 100644
--- a/python/pyspark/sql/tests/connect/test_parity_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_udf.py
@@ -165,11 +165,6 @@ class UDFParityTests(BaseUDFTestsMixin, ReusedConnectTestCase):
     def test_udf_in_left_outer_join_condition(self):
         super().test_udf_in_left_outer_join_condition()
 
-    # TODO(SPARK-42269): support return type as a collection DataType in DDL strings
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_udf_with_string_return_type(self):
-        super().test_udf_with_string_return_type()
-
     def test_udf_registration_returns_udf(self):
         df = self.spark.range(10)
         add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org