Posted to commits@spark.apache.org by ru...@apache.org on 2023/06/21 09:22:19 UTC
[spark] branch master updated: [SPARK-43903][PYTHON][CONNECT] Improve ArrayType input support in Arrow Python UDF
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ac1e2223105 [SPARK-43903][PYTHON][CONNECT] Improve ArrayType input support in Arrow Python UDF
ac1e2223105 is described below
commit ac1e22231055d7e59eec5dd8c6a807252aab8b7f
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Wed Jun 21 17:22:00 2023 +0800
[SPARK-43903][PYTHON][CONNECT] Improve ArrayType input support in Arrow Python UDF
### What changes were proposed in this pull request?
Improve ArrayType input support in Arrow Python UDF.
Previously, ArrayType was mapped to an `np.ndarray`; now it is mapped to a `list`, matching the behavior of pickled Python UDFs.
### Why are the changes needed?
To reach parity with pickled Python UDFs.
### Does this PR introduce _any_ user-facing change?
Yes.
FROM
```py
>>> df = spark.range(1).selectExpr("array(array(1, 2), array(3, 4)) as nested_array")
>>> df.select(udf(lambda x: str(x), returnType='string', useArrow=True)("nested_array")).first()
Row(<lambda>(nested_array)='[array([1, 2], dtype=int32) array([3, 4], dtype=int32)]')
```
TO
```py
>>> df = spark.range(1).selectExpr("array(array(1, 2), array(3, 4)) as nested_array")
>>> df.select(udf(lambda x: str(x), returnType='string', useArrow=True)("nested_array")).first()
Row(<lambda>(nested_array)='[[1, 2], [3, 4]]')
```
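As a parity check (an illustrative sketch reusing the same `df`; not part of this patch), disabling Arrow now yields the identical rendering:
```py
>>> # Illustrative only: the pickled (non-Arrow) UDF also receives plain lists.
>>> df.select(udf(lambda x: str(x), returnType='string', useArrow=False)("nested_array")).first()
Row(<lambda>(nested_array)='[[1, 2], [3, 4]]')
```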
### How was this patch tested?
Unit tests.
Closes #41603 from xinrong-meng/ndarr.
Authored-by: Xinrong Meng <xi...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
python/pyspark/sql/pandas/serializers.py | 9 ++--
python/pyspark/sql/pandas/types.py | 61 ++++++++++++++++-------
python/pyspark/sql/tests/test_arrow_python_udf.py | 17 +++++--
python/pyspark/worker.py | 2 +
4 files changed, 62 insertions(+), 27 deletions(-)
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 12d0bee88ad..307fcc33752 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -172,7 +172,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
self._timezone = timezone
self._safecheck = safecheck
- def arrow_to_pandas(self, arrow_column, struct_in_pandas="dict"):
+ def arrow_to_pandas(self, arrow_column, struct_in_pandas="dict", ndarray_as_list=False):
# If the given column is a date type column, creates a series of datetime.date directly
# instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
# datetime64[ns] type handling.
@@ -186,6 +186,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
timezone=self._timezone,
struct_in_pandas=struct_in_pandas,
error_on_duplicated_field_names=True,
+ ndarray_as_list=ndarray_as_list,
)
return converter(s)
@@ -317,11 +318,13 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
assign_cols_by_name,
df_for_struct=False,
struct_in_pandas="dict",
+ ndarray_as_list=False,
):
super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck)
self._assign_cols_by_name = assign_cols_by_name
self._df_for_struct = df_for_struct
self._struct_in_pandas = struct_in_pandas
+ self._ndarray_as_list = ndarray_as_list
def arrow_to_pandas(self, arrow_column):
import pyarrow.types as types
@@ -331,14 +334,14 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
series = [
super(ArrowStreamPandasUDFSerializer, self)
- .arrow_to_pandas(column, self._struct_in_pandas)
+ .arrow_to_pandas(column, self._struct_in_pandas, self._ndarray_as_list)
.rename(field.name)
for column, field in zip(arrow_column.flatten(), arrow_column.type)
]
s = pd.concat(series, axis=1)
else:
s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(
- arrow_column, self._struct_in_pandas
+ arrow_column, self._struct_in_pandas, self._ndarray_as_list
)
return s
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 757deff6130..53362047604 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -494,6 +494,7 @@ def _create_converter_to_pandas(
struct_in_pandas: Optional[str] = None,
error_on_duplicated_field_names: bool = True,
timestamp_utc_localized: bool = True,
+ ndarray_as_list: bool = False,
) -> Callable[["pd.Series"], "pd.Series"]:
"""
Create a converter of pandas Series that is created from Spark's Python objects,
@@ -520,6 +521,8 @@ def _create_converter_to_pandas(
Whether the timestamp values are localized to UTC or not.
The timestamp values from Arrow are localized to UTC,
whereas the ones from `df.collect()` are localized to the local timezone.
+ ndarray_as_list : bool, optional
+ Whether `np.ndarray` is converted to a list or not (default ``False``).
Returns
-------
@@ -569,30 +572,47 @@ def _create_converter_to_pandas(
return correct_dtype
def _converter(
- dt: DataType, _struct_in_pandas: Optional[str]
+ dt: DataType, _struct_in_pandas: Optional[str], _ndarray_as_list: bool
) -> Optional[Callable[[Any], Any]]:
if isinstance(dt, ArrayType):
- _element_conv = _converter(dt.elementType, _struct_in_pandas)
- if _element_conv is None:
- return None
+ _element_conv = _converter(dt.elementType, _struct_in_pandas, _ndarray_as_list)
- def convert_array(value: Any) -> Any:
- if value is None:
+ if _ndarray_as_list:
+ if _element_conv is None:
+ _element_conv = lambda x: x # noqa: E731
+
+ def convert_array_ndarray_as_list(value: Any) -> Any:
+ if value is None:
+ return None
+ else:
+ # In Arrow Python UDF, ArrayType is converted to `np.ndarray`
+ # whereas a list is expected.
+ return [_element_conv(v) for v in value] # type: ignore[misc]
+
+ return convert_array_ndarray_as_list
+ else:
+ if _element_conv is None:
return None
- elif isinstance(value, np.ndarray):
- # `pyarrow.Table.to_pandas` uses `np.ndarray`.
- return np.array([_element_conv(v) for v in value]) # type: ignore[misc]
- else:
- assert isinstance(value, list)
- # otherwise, `list` should be used.
- return [_element_conv(v) for v in value] # type: ignore[misc]
- return convert_array
+ def convert_array_ndarray_as_ndarray(value: Any) -> Any:
+ if value is None:
+ return None
+ elif isinstance(value, np.ndarray):
+ # `pyarrow.Table.to_pandas` uses `np.ndarray`.
+ return np.array([_element_conv(v) for v in value]) # type: ignore[misc]
+ else:
+ assert isinstance(value, list)
+ # otherwise, `list` should be used.
+ return [_element_conv(v) for v in value] # type: ignore[misc]
+
+ return convert_array_ndarray_as_ndarray
elif isinstance(dt, MapType):
- _key_conv = _converter(dt.keyType, _struct_in_pandas) or (lambda x: x)
- _value_conv = _converter(dt.valueType, _struct_in_pandas) or (lambda x: x)
+ _key_conv = _converter(dt.keyType, _struct_in_pandas, _ndarray_as_list) or (lambda x: x)
+ _value_conv = _converter(dt.valueType, _struct_in_pandas, _ndarray_as_list) or (
+ lambda x: x
+ )
def convert_map(value: Any) -> Any:
if value is None:
@@ -621,7 +641,8 @@ def _create_converter_to_pandas(
dedup_field_names = _dedup_names(field_names)
field_convs = [
- _converter(f.dataType, _struct_in_pandas) or (lambda x: x) for f in dt.fields
+ _converter(f.dataType, _struct_in_pandas, _ndarray_as_list) or (lambda x: x)
+ for f in dt.fields
]
if _struct_in_pandas == "row":
@@ -699,7 +720,9 @@ def _create_converter_to_pandas(
elif isinstance(dt, UserDefinedType):
udt: UserDefinedType = dt
- conv = _converter(udt.sqlType(), _struct_in_pandas="row") or (lambda x: x)
+ conv = _converter(udt.sqlType(), _struct_in_pandas="row", _ndarray_as_list=True) or (
+ lambda x: x
+ )
def convert_udt(value: Any) -> Any:
if value is None:
@@ -715,7 +738,7 @@ def _create_converter_to_pandas(
else:
return None
- conv = _converter(data_type, struct_in_pandas)
+ conv = _converter(data_type, struct_in_pandas, ndarray_as_list)
if conv is not None:
return lambda pser: pser.apply(conv) # type: ignore[return-value]
else:
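Stripped of the converter plumbing, the new array handling boils down to this standalone sketch (a hypothetical, simplified stand-in for `convert_array_ndarray_as_list` / `convert_array_ndarray_as_ndarray` above):
```py
import numpy as np

# Hypothetical, simplified sketch of the two conversion paths in the diff above.
def convert_array(value, elem_conv=lambda x: x, ndarray_as_list=False):
    if value is None:
        return None
    if ndarray_as_list:
        # Arrow Python UDF path: always hand the UDF a plain list.
        return [elem_conv(v) for v in value]
    if isinstance(value, np.ndarray):
        # `pyarrow.Table.to_pandas` uses `np.ndarray`; keep that container.
        return np.array([elem_conv(v) for v in value])
    # Otherwise a `list` is expected; keep it as a list.
    return [elem_conv(v) for v in value]

assert convert_array(np.array([1, 2]), ndarray_as_list=True) == [1, 2]
```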
diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py
index c60a7ef648a..0accb0f3cc1 100644
--- a/python/pyspark/sql/tests/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -62,8 +62,7 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
.first()
)
- # The input is NumPy array when the optimization is on.
- self.assertEquals(row[0], "[1 2 3]")
+ self.assertEquals(row[0], "[1, 2, 3]")
self.assertEquals(row[1], "{'a': 'b'}")
self.assertEquals(row[2], "Row(col1=1, col2=2)")
@@ -92,8 +91,7 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
.first()
)
- # The input is a NumPy array when the Arrow optimization is on.
- self.assertEquals(row_true[0], row_none[0]) # "[1 2 3]"
+ self.assertEquals(row_true[0], row_none[0]) # "[1, 2, 3]"
# useArrow=False
row_false = (
@@ -125,7 +123,7 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
# To verify that Arrow optimization is on
self.assertEquals(
df.selectExpr("str_repr(array) AS str_id").first()[0],
- "[1 2 3]", # The input is a NumPy array when the Arrow optimization is on
+ "[1, 2, 3]", # The input is a NumPy array when the Arrow optimization is on
)
# To verify that a UserDefinedFunction is returned
@@ -134,6 +132,15 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
df.select(str_repr_func("array").alias("str_id")).collect(),
)
+ def test_nested_array_input(self):
+ df = self.spark.range(1).selectExpr("array(array(1, 2), array(3, 4)) as nested_array")
+ self.assertEquals(
+ df.select(
+ udf(lambda x: str(x), returnType="string", useArrow=True)("nested_array")
+ ).first()[0],
+ "[[1, 2], [3, 4]]",
+ )
+
class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index aaf38fc145a..71a7ccd15aa 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -597,12 +597,14 @@ def read_udfs(pickleSer, infile, eval_type):
struct_in_pandas = (
"row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict"
)
+ ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
ser = ArrowStreamPandasUDFSerializer(
timezone,
safecheck,
assign_cols_by_name(runner_conf),
df_for_struct,
struct_in_pandas,
+ ndarray_as_list,
)
else:
ser = BatchedSerializer(CPickleSerializer(), 100)
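Note that `ndarray_as_list` is enabled only for `SQL_ARROW_BATCHED_UDF`, so Pandas UDFs keep receiving `np.ndarray` elements for ArrayType input; a quick illustrative check (hypothetical session, not part of this patch):
```py
>>> from pyspark.sql.functions import pandas_udf
>>> import pandas as pd
>>> @pandas_udf("string")
... def elem_type(s: pd.Series) -> pd.Series:
...     return s.apply(lambda x: type(x).__name__)
...
>>> spark.range(1).selectExpr("array(1, 2) as arr").select(elem_type("arr")).first()
Row(elem_type(arr)='ndarray')
```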
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org