Posted to commits@spark.apache.org by xi...@apache.org on 2023/06/06 22:48:39 UTC

[spark] branch master updated: [SPARK-43893][PYTHON][CONNECT] Non-atomic data type support in Arrow-optimized Python UDF

This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 94098853592 [SPARK-43893][PYTHON][CONNECT] Non-atomic data type support in Arrow-optimized Python UDF
94098853592 is described below

commit 94098853592b524f52e9a340166b96ddeda4e898
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Tue Jun 6 15:48:14 2023 -0700

    [SPARK-43893][PYTHON][CONNECT] Non-atomic data type support in Arrow-optimized Python UDF
    
    ### What changes were proposed in this pull request?
    Support non-atomic data types as both input and output of Arrow-optimized Python UDFs.
    
    Non-atomic data types here refer to ArrayType, MapType, and StructType.
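    
    For instance, a UDF can now both take and return a nested map (a doctest-style sketch mirroring the new test_nested_map case; the exact Row rendering on the last line is illustrative):
    ```py
    >>> df = spark.range(1).selectExpr("map('a', map('b', 'c')) as nested_map")
    >>> @udf(returnType=df.dtypes[0][1])
    ... def f(x):
    ...     x["a"]["b"] = "d"
    ...     return x
    ...
    >>> df.select(f("nested_map")).first()
    Row(f(nested_map)={'a': {'b': 'd'}})
    ```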
    
    ### Why are the changes needed?
    Parity with pickled Python UDFs.
    
    ### Does this PR introduce _any_ user-facing change?
    Non-atomic data types are accepted as both input and output of Arrow-optimized Python UDFs.
    
    For example,
    ```py
    >>> df = spark.range(1).selectExpr("struct(1, struct('John', 30, ('value', 10))) as nested_struct")
    >>> df.select(udf(lambda x: str(x))("nested_struct")).first()
    Row(<lambda>(nested_struct)="Row(col1=1, col2=Row(col1='John', col2=30, col3=Row(col1='value', col2=10)))")
    ```
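    
    Array and map inputs work the same way (adapted from the updated test_complex_input_types; an array arrives as a NumPy array when the optimization is on, hence the '[1 2 3]' rendering):
    ```py
    >>> df = spark.range(1).selectExpr("array(1, 2, 3) as array", "map('a', 'b') as map")
    >>> row = df.select(udf(lambda x: str(x))("array"), udf(lambda x: str(x))("map")).first()
    >>> row[0], row[1]
    ('[1 2 3]', "{'a': 'b'}")
    ```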
    
    ### How was this patch tested?
    Unit tests: new test_nested_struct, test_nested_map, and test_nested_array cases in test_udf.py, plus adjusted Arrow-specific tests in test_arrow_python_udf.py.
    
    Closes #41321 from xinrong-meng/arrow_udf_struct.
    
    Authored-by: Xinrong Meng <xi...@apache.org>
    Signed-off-by: Xinrong Meng <xi...@apache.org>
---
 python/pyspark/sql/pandas/serializers.py          | 22 ++++++++---
 python/pyspark/sql/tests/test_arrow_python_udf.py | 17 ++++-----
 python/pyspark/sql/tests/test_udf.py              | 45 +++++++++++++++++++++++
 python/pyspark/sql/udf.py                         | 15 +-------
 python/pyspark/worker.py                          | 13 +++++--
 5 files changed, 79 insertions(+), 33 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 84471143367..12d0bee88ad 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -172,7 +172,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         self._timezone = timezone
         self._safecheck = safecheck
 
-    def arrow_to_pandas(self, arrow_column):
+    def arrow_to_pandas(self, arrow_column, struct_in_pandas="dict"):
         # If the given column is a date type column, creates a series of datetime.date directly
         # instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
         # datetime64[ns] type handling.
@@ -184,7 +184,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
             data_type=from_arrow_type(arrow_column.type, prefer_timestamp_ntz=True),
             nullable=True,
             timezone=self._timezone,
-            struct_in_pandas="dict",
+            struct_in_pandas=struct_in_pandas,
             error_on_duplicated_field_names=True,
         )
         return converter(s)
@@ -310,10 +310,18 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
     Serializer used by Python worker to evaluate Pandas UDFs
     """
 
-    def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        df_for_struct=False,
+        struct_in_pandas="dict",
+    ):
         super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck)
         self._assign_cols_by_name = assign_cols_by_name
         self._df_for_struct = df_for_struct
+        self._struct_in_pandas = struct_in_pandas
 
     def arrow_to_pandas(self, arrow_column):
         import pyarrow.types as types
@@ -323,13 +331,15 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
 
             series = [
                 super(ArrowStreamPandasUDFSerializer, self)
-                .arrow_to_pandas(column)
+                .arrow_to_pandas(column, self._struct_in_pandas)
                 .rename(field.name)
                 for column, field in zip(arrow_column.flatten(), arrow_column.type)
             ]
             s = pd.concat(series, axis=1)
         else:
-            s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
+            s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(
+                arrow_column, self._struct_in_pandas
+            )
         return s
 
     def _create_batch(self, series):
@@ -360,7 +370,7 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
 
         arrs = []
         for s, t in series:
-            if t is not None and pa.types.is_struct(t):
+            if self._struct_in_pandas == "dict" and t is not None and pa.types.is_struct(t):
                 if not isinstance(s, pd.DataFrame):
                     raise PySparkValueError(
                         "A field of type StructType expects a pandas.DataFrame, "
diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py
index 3266168f290..c60a7ef648a 100644
--- a/python/pyspark/sql/tests/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -45,23 +45,19 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
     def test_register_java_udaf(self):
         super(PythonUDFArrowTests, self).test_register_java_udaf()
 
-    @unittest.skip("Struct input types are not supported with Arrow optimization")
-    def test_udf_input_serialization_valuecompare_disabled(self):
-        super(PythonUDFArrowTests, self).test_udf_input_serialization_valuecompare_disabled()
-
-    def test_nested_input_error(self):
-        with self.assertRaisesRegexp(Exception, "[NotImplementedError]"):
-            self.spark.range(1).selectExpr("struct(1, 2) as struct").select(
-                udf(lambda x: x)("struct")
-            ).collect()
+    # TODO(SPARK-43903): Standardize ArrayType conversion for Python UDF
+    @unittest.skip("Inconsistent ArrayType conversion with/without Arrow.")
+    def test_nested_array(self):
+        super(PythonUDFArrowTests, self).test_nested_array()
 
     def test_complex_input_types(self):
         row = (
             self.spark.range(1)
-            .selectExpr("array(1, 2, 3) as array", "map('a', 'b') as map")
+            .selectExpr("array(1, 2, 3) as array", "map('a', 'b') as map", "struct(1, 2) as struct")
             .select(
                 udf(lambda x: str(x))("array"),
                 udf(lambda x: str(x))("map"),
+                udf(lambda x: str(x))("struct"),
             )
             .first()
         )
@@ -69,6 +65,7 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
         # The input is NumPy array when the optimization is on.
         self.assertEquals(row[0], "[1 2 3]")
         self.assertEquals(row[1], "{'a': 'b'}")
+        self.assertEquals(row[2], "Row(col1=1, col2=2)")
 
     def test_use_arrow(self):
         # useArrow=True
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 3e33040cc70..8ffcb5e05a2 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -837,6 +837,51 @@ class BaseUDFTestsMixin(object):
             len(self.spark.range(10).select(udf(lambda x: x, DoubleType())(rand())).collect()), 10
         )
 
+    def test_nested_struct(self):
+        df = self.spark.range(1).selectExpr(
+            "struct(1, struct('John', 30, ('value', 10))) as nested_struct"
+        )
+        # Input
+        row = df.select(udf(lambda x: str(x))("nested_struct")).first()
+        self.assertEquals(
+            row[0], "Row(col1=1, col2=Row(col1='John', col2=30, col3=Row(col1='value', col2=10)))"
+        )
+        # Output
+        row = df.select(udf(lambda x: x, returnType=df.dtypes[0][1])("nested_struct")).first()
+        self.assertEquals(
+            row[0], Row(col1=1, col2=Row(col1="John", col2=30, col3=Row(col1="value", col2=10)))
+        )
+
+    def test_nested_map(self):
+        df = self.spark.range(1).selectExpr("map('a', map('b', 'c')) as nested_map")
+        # Input
+        row = df.select(udf(lambda x: str(x))("nested_map")).first()
+        self.assertEquals(row[0], "{'a': {'b': 'c'}}")
+        # Output
+
+        @udf(returnType=df.dtypes[0][1])
+        def f(x):
+            x["a"]["b"] = "d"
+            return x
+
+        row = df.select(f("nested_map")).first()
+        self.assertEquals(row[0], {"a": {"b": "d"}})
+
+    def test_nested_array(self):
+        df = self.spark.range(1).selectExpr("array(array(1, 2), array(3, 4)) as nested_array")
+        # Input
+        row = df.select(udf(lambda x: str(x))("nested_array")).first()
+        self.assertEquals(row[0], "[[1, 2], [3, 4]]")
+        # Output
+
+        @udf(returnType=df.dtypes[0][1])
+        def f(x):
+            x.append([4, 5])
+            return x
+
+        row = df.select(f("nested_array")).first()
+        self.assertEquals(row[0], [[1, 2], [3, 4], [4, 5]])
+
 
 class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 87d53266edf..c6171ffece9 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -32,10 +32,8 @@ from pyspark.profiler import Profiler
 from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
 from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq
 from pyspark.sql.types import (
-    ArrayType,
     BinaryType,
     DataType,
-    MapType,
     StringType,
     StructType,
     _parse_datatype_string,
@@ -129,18 +127,12 @@ def _create_py_udf(
             else useArrow
         )
     regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
-    return_type = regular_udf.returnType
     try:
         is_func_with_args = len(getfullargspec(f).args) > 0
     except TypeError:
         is_func_with_args = False
-    is_output_atomic_type = (
-        not isinstance(return_type, StructType)
-        and not isinstance(return_type, MapType)
-        and not isinstance(return_type, ArrayType)
-    )
     if is_arrow_enabled:
-        if is_output_atomic_type and is_func_with_args:
+        if is_func_with_args:
             return _create_arrow_py_udf(regular_udf)
         else:
             warnings.warn(
@@ -175,11 +167,6 @@ def _create_arrow_py_udf(regular_udf):  # type: ignore
         result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
 
     def vectorized_udf(*args: pd.Series) -> pd.Series:
-        if any(map(lambda arg: isinstance(arg, pd.DataFrame), args)):
-            raise PySparkNotImplementedError(
-                error_class="UNSUPPORTED_WITH_ARROW_OPTIMIZATION",
-                message_parameters={"feature": "Struct input type"},
-            )
         return pd.Series(result_func(f(*a)) for a in zip(*args))
 
     # Regular UDFs can take callable instances too.
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 9bd8df077b6..06f0d1dc37f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -511,13 +511,20 @@ def read_udfs(pickleSer, infile, eval_type):
             # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
             # pandas Series. See SPARK-27240.
             df_for_struct = (
-                eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
-                or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
+                eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
                 or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
                 or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
             )
+            # Arrow-optimized Python UDF takes a struct type argument as a Row
+            struct_in_pandas = (
+                "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict"
+            )
             ser = ArrowStreamPandasUDFSerializer(
-                timezone, safecheck, assign_cols_by_name(runner_conf), df_for_struct
+                timezone,
+                safecheck,
+                assign_cols_by_name(runner_conf),
+                df_for_struct,
+                struct_in_pandas,
             )
     else:
         ser = BatchedSerializer(CPickleSerializer(), 100)

