Posted to commits@spark.apache.org by xi...@apache.org on 2023/06/29 18:46:19 UTC

[spark] branch master updated: [SPARK-44150][PYTHON][CONNECT] Explicit Arrow casting for mismatched return type in Arrow Python UDF

This is an automated email from the ASF dual-hosted git repository.

xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6e56cfeaca8 [SPARK-44150][PYTHON][CONNECT] Explicit Arrow casting for mismatched return type in Arrow Python UDF
6e56cfeaca8 is described below

commit 6e56cfeaca884b1ccfaa8524c70f12f118bc840c
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Thu Jun 29 11:46:06 2023 -0700

    [SPARK-44150][PYTHON][CONNECT] Explicit Arrow casting for mismatched return type in Arrow Python UDF
    
    ### What changes were proposed in this pull request?
    Explicit Arrow casting for the mismatched return type of Arrow Python UDF.
    
    ### Why are the changes needed?
    To provide more standardized and coherent type coercion for the results of Arrow Python UDFs.
    
    Please refer to https://github.com/apache/spark/pull/41706 for a comprehensive comparison between the type coercion rules of Arrow and those of Pickle (used by the default Python UDF).
    
    See more at [[Design] Type-coercion in Arrow Python UDFs](https://docs.google.com/document/d/e/2PACX-1vTEGElOZfhl9NfgbBw4CTrlm-8F_xQCAKNOXouz-7mg5vYobS7lCGUsGkDZxPY0wV5YkgoZmkYlxccU/pub).
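    
    As a rough sketch of the mechanism (simplified from `_create_array` in the diff below), the serializer first lets pyarrow infer the Arrow type of the UDF output, then casts the result to the Arrow type derived from the declared return type, honoring the existing safecheck flag:
    
    ```py
    >>> import pandas as pd
    >>> import pyarrow as pa
    >>> s = pd.Series(["1", "2"])  # e.g., a UDF that returned strings
    >>> pa.Array.from_pandas(s).cast(target_type=pa.int32(), safe=True)
    <pyarrow.lib.Int32Array object at ...>
    [
      1,
      2
    ]
    ```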
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    
    FROM
    ```py
    >>> df = spark.createDataFrame(['1', '2'], schema='string')
    >>> df.select(udf(lambda x: x, 'int', useArrow=True)('value')).show()
    ...
    org.apache.spark.api.python.PythonException: Traceback (most recent call last):
    ...
    pyarrow.lib.ArrowInvalid: Could not convert '1' with type str: tried to convert to int32
    ```
    
    TO
    ```py
    >>> df = spark.createDataFrame(['1', '2'], schema='string')
    >>> df.select(udf(lambda x: x, 'int', useArrow=True)('value')).show()
    +---------------+
    |<lambda>(value)|
    +---------------+
    |              1|
    |              2|
    +---------------+
    ```
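    
    Values that cannot be coerced losslessly still fail, now from the explicit cast itself; the new unit tests assert this via `PythonException`. A hypothetical example (output abridged):
    
    ```py
    >>> df = spark.createDataFrame(['1.1', '2.2'], schema='string')
    >>> df.select(udf(lambda x: x, 'int', useArrow=True)('value')).show()
    ...
    pyarrow.lib.ArrowInvalid: ...
    ```
    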
    ### How was this patch tested?
    Unit tests.
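    
    The new cases live in `test_arrow_python_udf.py` (diff below). A minimal sketch for running the suite locally, assuming a PySpark dev environment on this branch (the module is directly runnable, since pyspark test files carry a standard `__main__` runner):
    
    ```py
    import unittest
    
    # equivalent to: python python/pyspark/sql/tests/test_arrow_python_udf.py
    unittest.main(module="pyspark.sql.tests.test_arrow_python_udf", exit=False)
    ```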
    
    Closes #41503 from xinrong-meng/type_coersion.
    
    Authored-by: Xinrong Meng <xi...@apache.org>
    Signed-off-by: Xinrong Meng <xi...@apache.org>
---
 python/pyspark/sql/pandas/serializers.py          | 30 ++++++++++++++---
 python/pyspark/sql/tests/test_arrow_python_udf.py | 39 +++++++++++++++++++++++
 python/pyspark/worker.py                          |  3 ++
 3 files changed, 67 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 307fcc33752..a99eda9cbea 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -190,7 +190,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         )
         return converter(s)
 
-    def _create_array(self, series, arrow_type, spark_type=None):
+    def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
         """
         Create an Arrow Array from the given pandas.Series and optional type.
 
@@ -202,6 +202,9 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
             If None, pyarrow's inferred type will be used
         spark_type : DataType, optional
             If None, spark type converted from arrow_type will be used
+        arrow_cast : bool, optional
+            Whether to apply an explicit Arrow cast when the user-specified return type does
+            not match the actual return values.
 
         Returns
         -------
@@ -226,7 +229,12 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         else:
             mask = series.isnull()
         try:
-            return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=self._safecheck)
+            if arrow_cast:
+                return pa.Array.from_pandas(series, mask=mask).cast(
+                    target_type=arrow_type, safe=self._safecheck
+                )
+            else:
+                return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=self._safecheck)
         except TypeError as e:
             error_msg = (
                 "Exception thrown when converting pandas.Series (%s) "
@@ -319,12 +327,14 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
         df_for_struct=False,
         struct_in_pandas="dict",
         ndarray_as_list=False,
+        arrow_cast=False,
     ):
         super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck)
         self._assign_cols_by_name = assign_cols_by_name
         self._df_for_struct = df_for_struct
         self._struct_in_pandas = struct_in_pandas
         self._ndarray_as_list = ndarray_as_list
+        self._arrow_cast = arrow_cast
 
     def arrow_to_pandas(self, arrow_column):
         import pyarrow.types as types
@@ -386,7 +396,13 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
                 # Assign result columns by schema name if user labeled with strings
                 elif self._assign_cols_by_name and any(isinstance(name, str) for name in s.columns):
                     arrs_names = [
-                        (self._create_array(s[field.name], field.type), field.name) for field in t
+                        (
+                            self._create_array(
+                                s[field.name], field.type, arrow_cast=self._arrow_cast
+                            ),
+                            field.name,
+                        )
+                        for field in t
                     ]
                 # Assign result columns by position
                 else:
@@ -394,7 +410,11 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
                         # the selected series has name '1', so we rename it to field.name
                         # as the name is used by _create_array to provide a meaningful error message
                         (
-                            self._create_array(s[s.columns[i]].rename(field.name), field.type),
+                            self._create_array(
+                                s[s.columns[i]].rename(field.name),
+                                field.type,
+                                arrow_cast=self._arrow_cast,
+                            ),
                             field.name,
                         )
                         for i, field in enumerate(t)
@@ -403,7 +423,7 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
                 struct_arrs, struct_names = zip(*arrs_names)
                 arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
             else:
-                arrs.append(self._create_array(s, t))
+                arrs.append(self._create_array(s, t, arrow_cast=self._arrow_cast))
 
         return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
 
diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py
index 0accb0f3cc1..264ea0b901f 100644
--- a/python/pyspark/sql/tests/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -17,6 +17,8 @@
 
 import unittest
 
+from pyspark.errors import PythonException
+from pyspark.sql import Row
 from pyspark.sql.functions import udf
 from pyspark.sql.tests.test_udf import BaseUDFTestsMixin
 from pyspark.testing.sqlutils import (
@@ -141,6 +143,43 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
             "[[1, 2], [3, 4]]",
         )
 
+    def test_type_coercion_string_to_numeric(self):
+        df_int_value = self.spark.createDataFrame(["1", "2"], schema="string")
+        df_floating_value = self.spark.createDataFrame(["1.1", "2.2"], schema="string")
+
+        int_ddl_types = ["tinyint", "smallint", "int", "bigint"]
+        floating_ddl_types = ["double", "float"]
+
+        for ddl_type in int_ddl_types:
+            # df_int_value
+            res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res"))
+            self.assertEquals(res.collect(), [Row(res=1), Row(res=2)])
+            self.assertEquals(res.dtypes[0][1], ddl_type)
+
+        floating_results = [
+            [Row(res=1.1), Row(res=2.2)],
+            [Row(res=1.100000023841858), Row(res=2.200000047683716)],
+        ]
+        for ddl_type, floating_res in zip(floating_ddl_types, floating_results):
+            # df_int_value
+            res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res"))
+            self.assertEquals(res.collect(), [Row(res=1.0), Row(res=2.0)])
+            self.assertEquals(res.dtypes[0][1], ddl_type)
+            # df_floating_value
+            res = df_floating_value.select(udf(lambda x: x, ddl_type)("value").alias("res"))
+            self.assertEquals(res.collect(), floating_res)
+            self.assertEquals(res.dtypes[0][1], ddl_type)
+
+        # invalid
+        with self.assertRaises(PythonException):
+            df_floating_value.select(udf(lambda x: x, "int")("value").alias("res")).collect()
+
+        with self.assertRaises(PythonException):
+            df_int_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect()
+
+        with self.assertRaises(PythonException):
+            df_floating_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect()
+
 
 class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 71a7ccd15aa..577286a7357 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -598,6 +598,8 @@ def read_udfs(pickleSer, infile, eval_type):
                 "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict"
             )
             ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
+            # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion
+            arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
             ser = ArrowStreamPandasUDFSerializer(
                 timezone,
                 safecheck,
@@ -605,6 +607,7 @@ def read_udfs(pickleSer, infile, eval_type):
                 df_for_struct,
                 struct_in_pandas,
                 ndarray_as_list,
+                arrow_cast,
             )
     else:
         ser = BatchedSerializer(CPickleSerializer(), 100)


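A quick end-to-end sketch of the change, assuming Spark built from this commit and a PySpark shell (`spark` in scope). Instead of the per-UDF `useArrow=True` flag, the session-wide conf (available since Spark 3.4) can opt all Python UDFs into the Arrow path:

```py
from pyspark.sql.functions import udf

# enable Arrow-optimized Python UDFs for the session
spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")

df = spark.createDataFrame(["1", "2"], schema="string")
# string results are now explicitly cast to the declared int type
df.select(udf(lambda x: x, "int")("value")).show()
```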