Posted to commits@spark.apache.org by ru...@apache.org on 2023/06/21 09:22:19 UTC

[spark] branch master updated: [SPARK-43903][PYTHON][CONNECT] Improve ArrayType input support in Arrow Python UDF

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ac1e2223105 [SPARK-43903][PYTHON][CONNECT] Improve ArrayType input support in Arrow Python UDF
ac1e2223105 is described below

commit ac1e22231055d7e59eec5dd8c6a807252aab8b7f
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Wed Jun 21 17:22:00 2023 +0800

    [SPARK-43903][PYTHON][CONNECT] Improve ArrayType input support in Arrow Python UDF
    
    ### What changes were proposed in this pull request?
    Improve ArrayType input support in Arrow Python UDF.
    
    Previously, an ArrayType input was mapped to an `np.ndarray`; now it is mapped to a `list`, matching the behavior of the pickled Python UDF.
    
    ### Why are the changes needed?
    To reach parity with the pickled Python UDF, which receives ArrayType inputs as Python lists.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    
    FROM
    ```py
    >>> df = spark.range(1).selectExpr("array(array(1, 2), array(3, 4)) as nested_array")
    >>> df.select(udf(lambda x: str(x), returnType='string', useArrow=True)("nested_array")).first()
    Row(<lambda>(nested_array)='[array([1, 2], dtype=int32) array([3, 4], dtype=int32)]')
    ```
    
    TO
    ```py
    >>> df = spark.range(1).selectExpr("array(array(1, 2), array(3, 4)) as nested_array")
    >>> df.select(udf(lambda x: str(x), returnType='string', useArrow=True)("nested_array")).first()
    Row(<lambda>(nested_array)='[[1, 2], [3, 4]]')
    ```
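
    For comparison, a sketch of the pickled (non-Arrow) path that this change matches, reusing the same `df`:
    ```py
    >>> df.select(udf(lambda x: str(x), returnType='string', useArrow=False)("nested_array")).first()
    Row(<lambda>(nested_array)='[[1, 2], [3, 4]]')
    ```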
    
    ### How was this patch tested?
    Unit tests.
    
    Closes #41603 from xinrong-meng/ndarr.
    
    Authored-by: Xinrong Meng <xi...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/sql/pandas/serializers.py          |  9 ++--
 python/pyspark/sql/pandas/types.py                | 61 ++++++++++++++++-------
 python/pyspark/sql/tests/test_arrow_python_udf.py | 17 +++++--
 python/pyspark/worker.py                          |  2 +
 4 files changed, 62 insertions(+), 27 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 12d0bee88ad..307fcc33752 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -172,7 +172,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         self._timezone = timezone
         self._safecheck = safecheck
 
-    def arrow_to_pandas(self, arrow_column, struct_in_pandas="dict"):
+    def arrow_to_pandas(self, arrow_column, struct_in_pandas="dict", ndarray_as_list=False):
         # If the given column is a date type column, creates a series of datetime.date directly
         # instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
         # datetime64[ns] type handling.
@@ -186,6 +186,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
             timezone=self._timezone,
             struct_in_pandas=struct_in_pandas,
             error_on_duplicated_field_names=True,
+            ndarray_as_list=ndarray_as_list,
         )
         return converter(s)
 
@@ -317,11 +318,13 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
         assign_cols_by_name,
         df_for_struct=False,
         struct_in_pandas="dict",
+        ndarray_as_list=False,
     ):
         super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck)
         self._assign_cols_by_name = assign_cols_by_name
         self._df_for_struct = df_for_struct
         self._struct_in_pandas = struct_in_pandas
+        self._ndarray_as_list = ndarray_as_list
 
     def arrow_to_pandas(self, arrow_column):
         import pyarrow.types as types
@@ -331,14 +334,14 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
 
             series = [
                 super(ArrowStreamPandasUDFSerializer, self)
-                .arrow_to_pandas(column, self._struct_in_pandas)
+                .arrow_to_pandas(column, self._struct_in_pandas, self._ndarray_as_list)
                 .rename(field.name)
                 for column, field in zip(arrow_column.flatten(), arrow_column.type)
             ]
             s = pd.concat(series, axis=1)
         else:
             s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(
-                arrow_column, self._struct_in_pandas
+                arrow_column, self._struct_in_pandas, self._ndarray_as_list
             )
         return s
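
A minimal standalone sketch (not Spark's actual code path) of what the new `ndarray_as_list` flag requests: `pyarrow.Table.to_pandas` hands array columns back as `np.ndarray` values, and each value is rewrapped as a Python `list`:

```py
import numpy as np
import pandas as pd

def convert_array_ndarray_as_list(value):
    # None stays None; an np.ndarray value is rewrapped as a Python list.
    if value is None:
        return None
    return [v for v in value]

pser = pd.Series([np.array([1, 2], dtype=np.int32), None])
print(pser.apply(convert_array_ndarray_as_list).tolist())
# [[1, 2], None]  (element repr may vary across NumPy versions)
```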
 
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 757deff6130..53362047604 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -494,6 +494,7 @@ def _create_converter_to_pandas(
     struct_in_pandas: Optional[str] = None,
     error_on_duplicated_field_names: bool = True,
     timestamp_utc_localized: bool = True,
+    ndarray_as_list: bool = False,
 ) -> Callable[["pd.Series"], "pd.Series"]:
     """
     Create a converter of pandas Series that is created from Spark's Python objects,
@@ -520,6 +521,8 @@ def _create_converter_to_pandas(
         Whether the timestamp values are localized to UTC or not.
         The timestamp values from Arrow are localized to UTC,
         whereas the ones from `df.collect()` are localized to the local timezone.
+    ndarray_as_list : bool, optional
+        Whether `np.ndarray` is converted to a list or not (default ``False``).
 
     Returns
     -------
@@ -569,30 +572,47 @@ def _create_converter_to_pandas(
         return correct_dtype
 
     def _converter(
-        dt: DataType, _struct_in_pandas: Optional[str]
+        dt: DataType, _struct_in_pandas: Optional[str], _ndarray_as_list: bool
     ) -> Optional[Callable[[Any], Any]]:
 
         if isinstance(dt, ArrayType):
-            _element_conv = _converter(dt.elementType, _struct_in_pandas)
-            if _element_conv is None:
-                return None
+            _element_conv = _converter(dt.elementType, _struct_in_pandas, _ndarray_as_list)
 
-            def convert_array(value: Any) -> Any:
-                if value is None:
+            if _ndarray_as_list:
+                if _element_conv is None:
+                    _element_conv = lambda x: x  # noqa: E731
+
+                def convert_array_ndarray_as_list(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    else:
+                        # In Arrow Python UDF, ArrayType is converted to `np.ndarray`
+                        # whereas a list is expected.
+                        return [_element_conv(v) for v in value]  # type: ignore[misc]
+
+                return convert_array_ndarray_as_list
+            else:
+                if _element_conv is None:
                     return None
-                elif isinstance(value, np.ndarray):
-                    # `pyarrow.Table.to_pandas` uses `np.ndarray`.
-                    return np.array([_element_conv(v) for v in value])  # type: ignore[misc]
-                else:
-                    assert isinstance(value, list)
-                    # otherwise, `list` should be used.
-                    return [_element_conv(v) for v in value]  # type: ignore[misc]
 
-            return convert_array
+                def convert_array_ndarray_as_ndarray(value: Any) -> Any:
+                    if value is None:
+                        return None
+                    elif isinstance(value, np.ndarray):
+                        # `pyarrow.Table.to_pandas` uses `np.ndarray`.
+                        return np.array([_element_conv(v) for v in value])  # type: ignore[misc]
+                    else:
+                        assert isinstance(value, list)
+                        # otherwise, `list` should be used.
+                        return [_element_conv(v) for v in value]  # type: ignore[misc]
+
+                return convert_array_ndarray_as_ndarray
 
         elif isinstance(dt, MapType):
-            _key_conv = _converter(dt.keyType, _struct_in_pandas) or (lambda x: x)
-            _value_conv = _converter(dt.valueType, _struct_in_pandas) or (lambda x: x)
+            _key_conv = _converter(dt.keyType, _struct_in_pandas, _ndarray_as_list) or (lambda x: x)
+            _value_conv = _converter(dt.valueType, _struct_in_pandas, _ndarray_as_list) or (
+                lambda x: x
+            )
 
             def convert_map(value: Any) -> Any:
                 if value is None:
@@ -621,7 +641,8 @@ def _create_converter_to_pandas(
             dedup_field_names = _dedup_names(field_names)
 
             field_convs = [
-                _converter(f.dataType, _struct_in_pandas) or (lambda x: x) for f in dt.fields
+                _converter(f.dataType, _struct_in_pandas, _ndarray_as_list) or (lambda x: x)
+                for f in dt.fields
             ]
 
             if _struct_in_pandas == "row":
@@ -699,7 +720,9 @@ def _create_converter_to_pandas(
         elif isinstance(dt, UserDefinedType):
             udt: UserDefinedType = dt
 
-            conv = _converter(udt.sqlType(), _struct_in_pandas="row") or (lambda x: x)
+            conv = _converter(udt.sqlType(), _struct_in_pandas="row", _ndarray_as_list=True) or (
+                lambda x: x
+            )
 
             def convert_udt(value: Any) -> Any:
                 if value is None:
@@ -715,7 +738,7 @@ def _create_converter_to_pandas(
         else:
             return None
 
-    conv = _converter(data_type, struct_in_pandas)
+    conv = _converter(data_type, struct_in_pandas, ndarray_as_list)
     if conv is not None:
         return lambda pser: pser.apply(conv)  # type: ignore[return-value]
     else:
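
A hedged usage sketch of the new parameter; `_create_converter_to_pandas` is internal API, so treat the exact signature as subject to change:

```py
import numpy as np
import pandas as pd

from pyspark.sql.pandas.types import _create_converter_to_pandas
from pyspark.sql.types import ArrayType, IntegerType

pser = pd.Series([np.array([1, 2], dtype=np.int32)])

# With ndarray_as_list=True, each np.ndarray value comes back as a list.
conv = _create_converter_to_pandas(ArrayType(IntegerType()), ndarray_as_list=True)
print(type(conv(pser)[0]).__name__)  # expected: 'list'
```
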
diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py
index c60a7ef648a..0accb0f3cc1 100644
--- a/python/pyspark/sql/tests/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -62,8 +62,7 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
             .first()
         )
 
-        # The input is NumPy array when the optimization is on.
-        self.assertEquals(row[0], "[1 2 3]")
+        self.assertEquals(row[0], "[1, 2, 3]")
         self.assertEquals(row[1], "{'a': 'b'}")
         self.assertEquals(row[2], "Row(col1=1, col2=2)")
 
@@ -92,8 +91,7 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
             .first()
         )
 
-        # The input is a NumPy array when the Arrow optimization is on.
-        self.assertEquals(row_true[0], row_none[0])  # "[1 2 3]"
+        self.assertEquals(row_true[0], row_none[0])  # "[1, 2, 3]"
 
         # useArrow=False
         row_false = (
@@ -125,7 +123,7 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
         # To verify that Arrow optimization is on
         self.assertEquals(
             df.selectExpr("str_repr(array) AS str_id").first()[0],
-            "[1 2 3]",  # The input is a NumPy array when the Arrow optimization is on
+            "[1, 2, 3]",  # The input is a NumPy array when the Arrow optimization is on
         )
 
         # To verify that a UserDefinedFunction is returned
@@ -134,6 +132,15 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
             df.select(str_repr_func("array").alias("str_id")).collect(),
         )
 
+    def test_nested_array_input(self):
+        df = self.spark.range(1).selectExpr("array(array(1, 2), array(3, 4)) as nested_array")
+        self.assertEquals(
+            df.select(
+                udf(lambda x: str(x), returnType="string", useArrow=True)("nested_array")
+            ).first()[0],
+            "[[1, 2], [3, 4]]",
+        )
+
 
 class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index aaf38fc145a..71a7ccd15aa 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -597,12 +597,14 @@ def read_udfs(pickleSer, infile, eval_type):
             struct_in_pandas = (
                 "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict"
             )
+            ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
             ser = ArrowStreamPandasUDFSerializer(
                 timezone,
                 safecheck,
                 assign_cols_by_name(runner_conf),
                 df_for_struct,
                 struct_in_pandas,
+                ndarray_as_list,
             )
     else:
         ser = BatchedSerializer(CPickleSerializer(), 100)
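
Since `ndarray_as_list` is enabled only for `SQL_ARROW_BATCHED_UDF`, pandas UDFs keep receiving `np.ndarray` for array inputs; only the Arrow-optimized Python UDF path changes. A hedged sketch of the difference, assuming an active SparkSession named `spark`:

```py
import pandas as pd

from pyspark.sql.functions import pandas_udf, udf

df = spark.range(1).selectExpr("array(1, 2, 3) as arr")

# Arrow-optimized Python UDF: the array input now arrives as a Python list.
arrow_type = udf(lambda x: type(x).__name__, returnType="string", useArrow=True)
print(df.select(arrow_type("arr")).first()[0])  # 'list'

# pandas UDF: each array value in the pd.Series is still an np.ndarray.
@pandas_udf("string")
def pandas_type(s: pd.Series) -> pd.Series:
    return s.apply(lambda x: type(x).__name__)

print(df.select(pandas_type("arr")).first()[0])  # 'ndarray'
```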

