Posted to commits@spark.apache.org by cu...@apache.org on 2019/03/07 16:53:04 UTC
[spark] branch master updated: [SPARK-23836][PYTHON] Add support for StructType return in Scalar Pandas UDF
This is an automated email from the ASF dual-hosted git repository.
cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ddc2052 [SPARK-23836][PYTHON] Add support for StructType return in Scalar Pandas UDF
ddc2052 is described below
commit ddc2052ebd247aa2a8dad34fd5c1cd345fa45118
Author: Bryan Cutler <cu...@gmail.com>
AuthorDate: Thu Mar 7 08:52:24 2019 -0800
[SPARK-23836][PYTHON] Add support for StructType return in Scalar Pandas UDF
## What changes were proposed in this pull request?
This change adds support for returning StructType from a scalar Pandas UDF, where the return value of the function is a pandas.DataFrame. Nested structs are not supported and will raise an error; child fields may be of any other currently supported type.
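A minimal usage sketch of the feature, mirroring the doctest added below (assuming a `SparkSession` named `spark` with PyArrow installed; data and names are illustrative):

    import pandas as pd
    from pyspark.sql.functions import pandas_udf

    @pandas_udf("first string, last string")  # scalar Pandas UDF with StructType return
    def split_expand(s):
        # Return a pandas.DataFrame; its columns map to the StructType fields
        return s.str.split(expand=True)

    df = spark.createDataFrame([("John Doe",)], ["name"])
    df.select(split_expand("name")).show()
    # +------------------+
    # |split_expand(name)|
    # +------------------+
    # |       [John, Doe]|
    # +------------------+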
## How was this patch tested?
Added additional unit tests to `test_pandas_udf_scalar`
Closes #23900 from BryanCutler/pyspark-support-scalar_udf-StructType-SPARK-23836.
Authored-by: Bryan Cutler <cu...@gmail.com>
Signed-off-by: Bryan Cutler <cu...@gmail.com>
---
python/pyspark/serializers.py | 39 +++++++++--
python/pyspark/sql/functions.py | 12 +++-
python/pyspark/sql/session.py | 3 +-
.../sql/tests/test_pandas_udf_grouped_map.py | 1 +
python/pyspark/sql/tests/test_pandas_udf_scalar.py | 81 +++++++++++++++++++++-
python/pyspark/sql/types.py | 8 ++-
python/pyspark/sql/udf.py | 5 +-
python/pyspark/worker.py | 12 +++-
8 files changed, 149 insertions(+), 12 deletions(-)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index a2c59fe..0c3c68e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -64,6 +64,7 @@ if sys.version < '3':
from itertools import izip as zip, imap as map
else:
import pickle
+ basestring = unicode = str
xrange = range
pickle_protocol = pickle.HIGHEST_PROTOCOL
@@ -244,7 +245,7 @@ class ArrowStreamSerializer(Serializer):
return "ArrowStreamSerializer"
-def _create_batch(series, timezone, safecheck):
+def _create_batch(series, timezone, safecheck, assign_cols_by_name):
"""
Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
@@ -254,6 +255,7 @@ def _create_batch(series, timezone, safecheck):
"""
import decimal
from distutils.version import LooseVersion
+ import pandas as pd
import pyarrow as pa
from pyspark.sql.types import _check_series_convert_timestamps_internal
# Make input conform to [(series1, type1), (series2, type2), ...]
@@ -295,7 +297,34 @@ def _create_batch(series, timezone, safecheck):
raise RuntimeError(error_msg % (s.dtype, t), e)
return array
- arrs = [create_array(s, t) for s, t in series]
+ arrs = []
+ for s, t in series:
+ if t is not None and pa.types.is_struct(t):
+ if not isinstance(s, pd.DataFrame):
+ raise ValueError("A field of type StructType expects a pandas.DataFrame, "
+ "but got: %s" % str(type(s)))
+
+ # Input partition and result pandas.DataFrame empty, make empty Arrays with struct
+ if len(s) == 0 and len(s.columns) == 0:
+ arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
+ # Assign result columns by schema name if user labeled with strings
+ elif assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns):
+ arrs_names = [(create_array(s[field.name], field.type), field.name) for field in t]
+ # Assign result columns by position
+ else:
+ arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
+ for i, field in enumerate(t)]
+
+ struct_arrs, struct_names = zip(*arrs_names)
+
+ # TODO: from_arrays args switched for v0.9.0, remove when bump minimum pyarrow version
+ if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
+ arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
+ else:
+ arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
+ else:
+ arrs.append(create_array(s, t))
+
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
@@ -304,10 +333,11 @@ class ArrowStreamPandasSerializer(Serializer):
Serializes Pandas.Series as Arrow data with Arrow streaming format.
"""
- def __init__(self, timezone, safecheck):
+ def __init__(self, timezone, safecheck, assign_cols_by_name):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone
self._safecheck = safecheck
+ self._assign_cols_by_name = assign_cols_by_name
def arrow_to_pandas(self, arrow_column):
from pyspark.sql.types import from_arrow_type, \
@@ -326,7 +356,8 @@ class ArrowStreamPandasSerializer(Serializer):
writer = None
try:
for series in iterator:
- batch = _create_batch(series, self._timezone, self._safecheck)
+ batch = _create_batch(series, self._timezone, self._safecheck,
+ self._assign_cols_by_name)
if writer is None:
write_int(SpecialLengths.START_ARROW_STREAM, stream)
writer = pa.RecordBatchStreamWriter(stream, batch.schema)
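For reference, the core of the new struct handling in `_create_batch` can be sketched standalone with pyarrow >= 0.9.0 (argument order per the version check above; data and names are illustrative):

    import pandas as pd
    import pyarrow as pa

    # Arrow struct type for the UDF's StructType return, and the
    # pandas.DataFrame the UDF produced
    t = pa.struct([pa.field("id", pa.int64()), pa.field("str", pa.string())])
    s = pd.DataFrame({"id": [1, 2], "str": ["a", "b"]})

    # String column labels -> assign result columns by schema name
    arrays = [pa.array(s[field.name], type=field.type) for field in t]
    names = [field.name for field in t]
    struct_arr = pa.StructArray.from_arrays(arrays, names)  # >= 0.9.0 argument order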
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3c33e2b..a36423e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2842,8 +2842,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`.
+ If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`.
- :class:`MapType`, :class:`StructType` are currently not supported as output types.
+ :class:`MapType`, nested :class:`StructType` are currently not supported as output types.
Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
:meth:`pyspark.sql.DataFrame.select`.
@@ -2868,6 +2869,15 @@ def pandas_udf(f=None, returnType=None, functionType=None):
+----------+--------------+------------+
| 8| JOHN DOE| 22|
+----------+--------------+------------+
+ >>> @pandas_udf("first string, last string") # doctest: +SKIP
+ ... def split_expand(n):
+ ... return n.str.split(expand=True)
+ >>> df.select(split_expand("name")).show() # doctest: +SKIP
+ +------------------+
+ |split_expand(name)|
+ +------------------+
+ | [John, Doe]|
+ +------------------+
.. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input
column, but is the length of an internal batch used for each call to the function.
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index bdf1701..32a2c8a 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -557,8 +557,9 @@ class SparkSession(object):
# Create Arrow record batches
safecheck = self._wrapped._conf.arrowSafeTypeConversion()
+ col_by_name = True # col by name only applies to StructType columns, can't happen here
batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
- timezone, safecheck)
+ timezone, safecheck, col_by_name)
for pdf_slice in pdf_slices]
# Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
index a0a2535..f7684d3 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -273,6 +273,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
StructField('map', MapType(StringType(), IntegerType())),
StructField('arr_ts', ArrayType(TimestampType())),
StructField('null', NullType()),
+ StructField('struct', StructType([StructField('l', LongType())])),
]
# TODO: Remove this if-statement once minimum pyarrow version is 0.10.0
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 28ef98d..28b6db2 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -23,13 +23,16 @@ import tempfile
import time
import unittest
+if sys.version >= '3':
+ unicode = str
+
from datetime import date, datetime
from decimal import Decimal
from distutils.version import LooseVersion
from pyspark.rdd import PythonEvalType
from pyspark.sql import Column
-from pyspark.sql.functions import array, col, expr, lit, sum, udf, pandas_udf
+from pyspark.sql.functions import array, col, expr, lit, sum, struct, udf, pandas_udf
from pyspark.sql.types import Row
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
@@ -265,6 +268,64 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
result = df.select(array_f(col('array')))
self.assertEquals(df.collect(), result.collect())
+ def test_vectorized_udf_struct_type(self):
+ import pandas as pd
+
+ df = self.spark.range(10)
+ return_type = StructType([
+ StructField('id', LongType()),
+ StructField('str', StringType())])
+
+ def func(id):
+ return pd.DataFrame({'id': id, 'str': id.apply(unicode)})
+
+ f = pandas_udf(func, returnType=return_type)
+
+ expected = df.select(struct(col('id'), col('id').cast('string').alias('str'))
+ .alias('struct')).collect()
+
+ actual = df.select(f(col('id')).alias('struct')).collect()
+ self.assertEqual(expected, actual)
+
+ g = pandas_udf(func, 'id: long, str: string')
+ actual = df.select(g(col('id')).alias('struct')).collect()
+ self.assertEqual(expected, actual)
+
+ def test_vectorized_udf_struct_complex(self):
+ import pandas as pd
+
+ df = self.spark.range(10)
+ return_type = StructType([
+ StructField('ts', TimestampType()),
+ StructField('arr', ArrayType(LongType()))])
+
+ @pandas_udf(returnType=return_type)
+ def f(id):
+ return pd.DataFrame({'ts': id.apply(lambda i: pd.Timestamp(i)),
+ 'arr': id.apply(lambda i: [i, i + 1])})
+
+ actual = df.withColumn('f', f(col('id'))).collect()
+ for i, row in enumerate(actual):
+ id, f = row
+ self.assertEqual(i, id)
+ self.assertEqual(pd.Timestamp(i).to_pydatetime(), f[0])
+ self.assertListEqual([i, i + 1], f[1])
+
+ def test_vectorized_udf_nested_struct(self):
+ nested_type = StructType([
+ StructField('id', IntegerType()),
+ StructField('nested', StructType([
+ StructField('foo', StringType()),
+ StructField('bar', FloatType())
+ ]))
+ ])
+
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ Exception,
+ 'Invalid returnType with scalar Pandas UDFs'):
+ pandas_udf(lambda x: x, returnType=nested_type)
+
def test_vectorized_udf_complex(self):
df = self.spark.range(10).select(
col('id').cast('int').alias('a'),
@@ -331,6 +392,20 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
res = df.select(f(col('id')))
self.assertEquals(df.collect(), res.collect())
+ def test_vectorized_udf_struct_with_empty_partition(self):
+ df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))\
+ .withColumn('name', lit('John Doe'))
+
+ @pandas_udf("first string, last string")
+ def split_expand(n):
+ return n.str.split(expand=True)
+
+ result = df.select(split_expand('name')).collect()
+ self.assertEqual(1, len(result))
+ row = result[0]
+ self.assertEqual('John', row[0]['first'])
+ self.assertEqual('Doe', row[0]['last'])
+
def test_vectorized_udf_varargs(self):
df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
f = pandas_udf(lambda *v: v[0], LongType())
@@ -343,6 +418,10 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
NotImplementedError,
'Invalid returnType.*scalar Pandas UDF.*MapType'):
pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ 'Invalid returnType.*scalar Pandas UDF.*ArrayType.StructType'):
+ pandas_udf(lambda x: x, ArrayType(StructType([StructField('a', IntegerType())])))
def test_vectorized_udf_dates(self):
schema = StructType().add("idx", LongType()).add("date", DateType())
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 348cb5b1..d87f0f9 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1613,9 +1613,15 @@ def to_arrow_type(dt):
# Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
arrow_type = pa.timestamp('us', tz='UTC')
elif type(dt) == ArrayType:
- if type(dt.elementType) == TimestampType:
+ if type(dt.elementType) in [StructType, TimestampType]:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
arrow_type = pa.list_(to_arrow_type(dt.elementType))
+ elif type(dt) == StructType:
+ if any(type(field.dataType) == StructType for field in dt):
+ raise TypeError("Nested StructType not supported in conversion to Arrow")
+ fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
+ for field in dt]
+ arrow_type = pa.struct(fields)
else:
raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
return arrow_type
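A quick sketch of the resulting `to_arrow_type` behavior after this hunk (flat structs convert, nesting fails fast; assuming pyarrow is installed):

    from pyspark.sql.types import (LongType, StringType, StructField,
                                   StructType, to_arrow_type)

    flat = StructType([StructField("id", LongType()),
                       StructField("str", StringType())])
    to_arrow_type(flat)  # -> pyarrow struct<id: int64, str: string>

    nested = StructType([StructField("inner", flat)])
    to_arrow_type(nested)  # raises TypeError: Nested StructType not supported ...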
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 58f4e0d..275abe9 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -123,7 +123,7 @@ class UserDefinedFunction(object):
elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
if isinstance(self._returnType_placeholder, StructType):
try:
- to_arrow_schema(self._returnType_placeholder)
+ to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
"Invalid returnType with grouped map Pandas UDFs: "
@@ -133,6 +133,9 @@ class UserDefinedFunction(object):
"UDFs: returnType must be a StructType.")
elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
try:
+ # StructType is not yet allowed as a return type, explicitly check here to fail fast
+ if isinstance(self._returnType_placeholder, StructType):
+ raise TypeError
to_arrow_type(self._returnType_placeholder)
except TypeError:
raise NotImplementedError(
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 01934a0..0e9b6d6 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -39,7 +39,7 @@ from pyspark.rdd import PythonEvalType
from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasSerializer
-from pyspark.sql.types import to_arrow_type
+from pyspark.sql.types import to_arrow_type, StructType
from pyspark.util import _get_argspec, fail_on_stopiteration
from pyspark import shuffle
@@ -90,8 +90,9 @@ def wrap_scalar_pandas_udf(f, return_type):
def verify_result_length(*a):
result = f(*a)
if not hasattr(result, "__len__"):
+ pd_type = "Pandas.DataFrame" if type(return_type) == StructType else "Pandas.Series"
raise TypeError("Return type of the user-defined function should be "
- "Pandas.Series, but is {}".format(type(result)))
+ "{}, but is {}".format(pd_type, type(result)))
if len(result) != len(a[0]):
raise RuntimeError("Result vector from pandas_udf was not the required length: "
"expected %d, got %d" % (len(a[0]), len(result)))
@@ -254,7 +255,12 @@ def read_udfs(pickleSer, infile, eval_type):
timezone = runner_conf.get("spark.sql.session.timeZone", None)
safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
"false").lower() == 'true'
- ser = ArrowStreamPandasSerializer(timezone, safecheck)
+ # NOTE: this is duplicated from wrap_grouped_map_pandas_udf
+ assign_cols_by_name = runner_conf.get(
+ "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
+ .lower() == "true"
+
+ ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
else:
ser = BatchedSerializer(PickleSerializer(), 100)
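Column assignment for struct results reuses the grouped-map conf named in the hunk above; a hedged sketch of toggling it on a live session (assuming a `SparkSession` named `spark`):

    # Default: result DataFrame columns labeled with strings are matched to
    # the StructType fields by name
    spark.conf.set(
        "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName",
        "true")

    # Legacy behavior: always assign result columns by position
    spark.conf.set(
        "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName",
        "false")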