You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by xi...@apache.org on 2023/06/06 22:48:39 UTC
[spark] branch master updated: [SPARK-43893][PYTHON][CONNECT] Non-atomic data type support in Arrow-optimized Python UDF
This is an automated email from the ASF dual-hosted git repository.
xinrong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 94098853592 [SPARK-43893][PYTHON][CONNECT] Non-atomic data type support in Arrow-optimized Python UDF
94098853592 is described below
commit 94098853592b524f52e9a340166b96ddeda4e898
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Tue Jun 6 15:48:14 2023 -0700
[SPARK-43893][PYTHON][CONNECT] Non-atomic data type support in Arrow-optimized Python UDF
### What changes were proposed in this pull request?
Support non-atomic data types in both the input and the output of Arrow-optimized Python UDFs.
Non-atomic data types are ArrayType, MapType, and StructType.
### Why are the changes needed?
For parity with pickled (non-Arrow) Python UDFs, which already accept these types.
### Does this PR introduce _any_ user-facing change?
Yes. Non-atomic data types are now accepted as both input and output of Arrow-optimized Python UDFs.
For example,
```py
>>> df = spark.range(1).selectExpr("struct(1, struct('John', 30, ('value', 10))) as nested_struct")
>>> df.select(udf(lambda x: str(x))("nested_struct")).first()
Row(<lambda>(nested_struct)="Row(col1=1, col2=Row(col1='John', col2=30, col3=Row(col1='value', col2=10)))")
```
### How was this patch tested?
Unit tests.
Closes #41321 from xinrong-meng/arrow_udf_struct.
Authored-by: Xinrong Meng <xi...@apache.org>
Signed-off-by: Xinrong Meng <xi...@apache.org>
---
python/pyspark/sql/pandas/serializers.py | 22 ++++++++---
python/pyspark/sql/tests/test_arrow_python_udf.py | 17 ++++-----
python/pyspark/sql/tests/test_udf.py | 45 +++++++++++++++++++++++
python/pyspark/sql/udf.py | 15 +-------
python/pyspark/worker.py | 13 +++++--
5 files changed, 79 insertions(+), 33 deletions(-)
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 84471143367..12d0bee88ad 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -172,7 +172,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
self._timezone = timezone
self._safecheck = safecheck
- def arrow_to_pandas(self, arrow_column):
+ def arrow_to_pandas(self, arrow_column, struct_in_pandas="dict"):
# If the given column is a date type column, creates a series of datetime.date directly
# instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
# datetime64[ns] type handling.
@@ -184,7 +184,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
data_type=from_arrow_type(arrow_column.type, prefer_timestamp_ntz=True),
nullable=True,
timezone=self._timezone,
- struct_in_pandas="dict",
+ struct_in_pandas=struct_in_pandas,
error_on_duplicated_field_names=True,
)
return converter(s)
@@ -310,10 +310,18 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
Serializer used by Python worker to evaluate Pandas UDFs
"""
- def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
+ def __init__(
+ self,
+ timezone,
+ safecheck,
+ assign_cols_by_name,
+ df_for_struct=False,
+ struct_in_pandas="dict",
+ ):
super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck)
self._assign_cols_by_name = assign_cols_by_name
self._df_for_struct = df_for_struct
+ self._struct_in_pandas = struct_in_pandas
def arrow_to_pandas(self, arrow_column):
import pyarrow.types as types
@@ -323,13 +331,15 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
series = [
super(ArrowStreamPandasUDFSerializer, self)
- .arrow_to_pandas(column)
+ .arrow_to_pandas(column, self._struct_in_pandas)
.rename(field.name)
for column, field in zip(arrow_column.flatten(), arrow_column.type)
]
s = pd.concat(series, axis=1)
else:
- s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column)
+ s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(
+ arrow_column, self._struct_in_pandas
+ )
return s
def _create_batch(self, series):
@@ -360,7 +370,7 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
arrs = []
for s, t in series:
- if t is not None and pa.types.is_struct(t):
+ if self._struct_in_pandas == "dict" and t is not None and pa.types.is_struct(t):
if not isinstance(s, pd.DataFrame):
raise PySparkValueError(
"A field of type StructType expects a pandas.DataFrame, "
diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py
index 3266168f290..c60a7ef648a 100644
--- a/python/pyspark/sql/tests/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -45,23 +45,19 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
def test_register_java_udaf(self):
super(PythonUDFArrowTests, self).test_register_java_udaf()
- @unittest.skip("Struct input types are not supported with Arrow optimization")
- def test_udf_input_serialization_valuecompare_disabled(self):
- super(PythonUDFArrowTests, self).test_udf_input_serialization_valuecompare_disabled()
-
- def test_nested_input_error(self):
- with self.assertRaisesRegexp(Exception, "[NotImplementedError]"):
- self.spark.range(1).selectExpr("struct(1, 2) as struct").select(
- udf(lambda x: x)("struct")
- ).collect()
+ # TODO(SPARK-43903): Standardize ArrayType conversion for Python UDF
+ @unittest.skip("Inconsistent ArrayType conversion with/without Arrow.")
+ def test_nested_array(self):
+ super(PythonUDFArrowTests, self).test_nested_array()
def test_complex_input_types(self):
row = (
self.spark.range(1)
- .selectExpr("array(1, 2, 3) as array", "map('a', 'b') as map")
+ .selectExpr("array(1, 2, 3) as array", "map('a', 'b') as map", "struct(1, 2) as struct")
.select(
udf(lambda x: str(x))("array"),
udf(lambda x: str(x))("map"),
+ udf(lambda x: str(x))("struct"),
)
.first()
)
@@ -69,6 +65,7 @@ class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
# The input is NumPy array when the optimization is on.
self.assertEquals(row[0], "[1 2 3]")
self.assertEquals(row[1], "{'a': 'b'}")
+ self.assertEquals(row[2], "Row(col1=1, col2=2)")
def test_use_arrow(self):
# useArrow=True
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 3e33040cc70..8ffcb5e05a2 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -837,6 +837,51 @@ class BaseUDFTestsMixin(object):
len(self.spark.range(10).select(udf(lambda x: x, DoubleType())(rand())).collect()), 10
)
+ def test_nested_struct(self):
+ df = self.spark.range(1).selectExpr(
+ "struct(1, struct('John', 30, ('value', 10))) as nested_struct"
+ )
+ # Input
+ row = df.select(udf(lambda x: str(x))("nested_struct")).first()
+ self.assertEquals(
+ row[0], "Row(col1=1, col2=Row(col1='John', col2=30, col3=Row(col1='value', col2=10)))"
+ )
+ # Output
+ row = df.select(udf(lambda x: x, returnType=df.dtypes[0][1])("nested_struct")).first()
+ self.assertEquals(
+ row[0], Row(col1=1, col2=Row(col1="John", col2=30, col3=Row(col1="value", col2=10)))
+ )
+
+ def test_nested_map(self):
+ df = self.spark.range(1).selectExpr("map('a', map('b', 'c')) as nested_map")
+ # Input
+ row = df.select(udf(lambda x: str(x))("nested_map")).first()
+ self.assertEquals(row[0], "{'a': {'b': 'c'}}")
+ # Output
+
+ @udf(returnType=df.dtypes[0][1])
+ def f(x):
+ x["a"]["b"] = "d"
+ return x
+
+ row = df.select(f("nested_map")).first()
+ self.assertEquals(row[0], {"a": {"b": "d"}})
+
+ def test_nested_array(self):
+ df = self.spark.range(1).selectExpr("array(array(1, 2), array(3, 4)) as nested_array")
+ # Input
+ row = df.select(udf(lambda x: str(x))("nested_array")).first()
+ self.assertEquals(row[0], "[[1, 2], [3, 4]]")
+ # Output
+
+ @udf(returnType=df.dtypes[0][1])
+ def f(x):
+ x.append([4, 5])
+ return x
+
+ row = df.select(f("nested_array")).first()
+ self.assertEquals(row[0], [[1, 2], [3, 4], [4, 5]])
+
class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 87d53266edf..c6171ffece9 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -32,10 +32,8 @@ from pyspark.profiler import Profiler
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType
from pyspark.sql.column import Column, _to_java_column, _to_java_expr, _to_seq
from pyspark.sql.types import (
- ArrayType,
BinaryType,
DataType,
- MapType,
StringType,
StructType,
_parse_datatype_string,
@@ -129,18 +127,12 @@ def _create_py_udf(
else useArrow
)
regular_udf = _create_udf(f, returnType, PythonEvalType.SQL_BATCHED_UDF)
- return_type = regular_udf.returnType
try:
is_func_with_args = len(getfullargspec(f).args) > 0
except TypeError:
is_func_with_args = False
- is_output_atomic_type = (
- not isinstance(return_type, StructType)
- and not isinstance(return_type, MapType)
- and not isinstance(return_type, ArrayType)
- )
if is_arrow_enabled:
- if is_output_atomic_type and is_func_with_args:
+ if is_func_with_args:
return _create_arrow_py_udf(regular_udf)
else:
warnings.warn(
@@ -175,11 +167,6 @@ def _create_arrow_py_udf(regular_udf): # type: ignore
result_func = lambda r: bytes(r) if r is not None else r # noqa: E731
def vectorized_udf(*args: pd.Series) -> pd.Series:
- if any(map(lambda arg: isinstance(arg, pd.DataFrame), args)):
- raise PySparkNotImplementedError(
- error_class="UNSUPPORTED_WITH_ARROW_OPTIMIZATION",
- message_parameters={"feature": "Struct input type"},
- )
return pd.Series(result_func(f(*a)) for a in zip(*args))
# Regular UDFs can take callable instances too.
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 9bd8df077b6..06f0d1dc37f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -511,13 +511,20 @@ def read_udfs(pickleSer, infile, eval_type):
# Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
# pandas Series. See SPARK-27240.
df_for_struct = (
- eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
- or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
+ eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
)
+ # Arrow-optimized Python UDF takes a struct type argument as a Row
+ struct_in_pandas = (
+ "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict"
+ )
ser = ArrowStreamPandasUDFSerializer(
- timezone, safecheck, assign_cols_by_name(runner_conf), df_for_struct
+ timezone,
+ safecheck,
+ assign_cols_by_name(runner_conf),
+ df_for_struct,
+ struct_in_pandas,
)
else:
ser = BatchedSerializer(CPickleSerializer(), 100)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org