Posted to commits@spark.apache.org by cu...@apache.org on 2019/03/07 16:53:04 UTC

[spark] branch master updated: [SPARK-23836][PYTHON] Add support for StructType return in Scalar Pandas UDF

This is an automated email from the ASF dual-hosted git repository.

cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ddc2052  [SPARK-23836][PYTHON] Add support for StructType return in Scalar Pandas UDF
ddc2052 is described below

commit ddc2052ebd247aa2a8dad34fd5c1cd345fa45118
Author: Bryan Cutler <cu...@gmail.com>
AuthorDate: Thu Mar 7 08:52:24 2019 -0800

    [SPARK-23836][PYTHON] Add support for StructType return in Scalar Pandas UDF
    
    ## What changes were proposed in this pull request?
    
    This change adds support for returning StructType from a scalar Pandas UDF, where the return value of the function is a pandas.DataFrame. Nested structs are not supported and will raise an error; the child fields can be any other currently supported type.
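    
    For example, a StructType-returning scalar UDF can now be written as below (mirroring the doctest added to `pandas_udf` in this patch; a sample `df` with a string `name` column is assumed):
    
    ```python
    import pandas as pd
    from pyspark.sql.functions import pandas_udf
    
    @pandas_udf("first string, last string")
    def split_expand(n):
        # Return a pandas.DataFrame; its columns map to the StructType fields
        # (by position here, since str.split labels the columns 0 and 1)
        return n.str.split(expand=True)
    
    df.select(split_expand("name")).show()
    ```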
    
    ## How was this patch tested?
    
    Added additional unit tests to `test_pandas_udf_scalar`
    
    Closes #23900 from BryanCutler/pyspark-support-scalar_udf-StructType-SPARK-23836.
    
    Authored-by: Bryan Cutler <cu...@gmail.com>
    Signed-off-by: Bryan Cutler <cu...@gmail.com>
---
 python/pyspark/serializers.py                      | 39 +++++++++--
 python/pyspark/sql/functions.py                    | 12 +++-
 python/pyspark/sql/session.py                      |  3 +-
 .../sql/tests/test_pandas_udf_grouped_map.py       |  1 +
 python/pyspark/sql/tests/test_pandas_udf_scalar.py | 81 +++++++++++++++++++++-
 python/pyspark/sql/types.py                        |  8 ++-
 python/pyspark/sql/udf.py                          |  5 +-
 python/pyspark/worker.py                           | 12 +++-
 8 files changed, 149 insertions(+), 12 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index a2c59fe..0c3c68e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -64,6 +64,7 @@ if sys.version < '3':
     from itertools import izip as zip, imap as map
 else:
     import pickle
+    basestring = unicode = str
     xrange = range
 pickle_protocol = pickle.HIGHEST_PROTOCOL
 
@@ -244,7 +245,7 @@ class ArrowStreamSerializer(Serializer):
         return "ArrowStreamSerializer"
 
 
-def _create_batch(series, timezone, safecheck):
+def _create_batch(series, timezone, safecheck, assign_cols_by_name):
     """
     Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
 
@@ -254,6 +255,7 @@ def _create_batch(series, timezone, safecheck):
     """
     import decimal
     from distutils.version import LooseVersion
+    import pandas as pd
     import pyarrow as pa
     from pyspark.sql.types import _check_series_convert_timestamps_internal
     # Make input conform to [(series1, type1), (series2, type2), ...]
@@ -295,7 +297,34 @@ def _create_batch(series, timezone, safecheck):
             raise RuntimeError(error_msg % (s.dtype, t), e)
         return array
 
-    arrs = [create_array(s, t) for s, t in series]
+    arrs = []
+    for s, t in series:
+        if t is not None and pa.types.is_struct(t):
+            if not isinstance(s, pd.DataFrame):
+                raise ValueError("A field of type StructType expects a pandas.DataFrame, "
+                                 "but got: %s" % str(type(s)))
+
+            # If the input partition and resulting pandas.DataFrame are empty, make empty Arrays with struct
+            if len(s) == 0 and len(s.columns) == 0:
+                arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
+            # Assign result columns by schema name if user labeled with strings
+            elif assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns):
+                arrs_names = [(create_array(s[field.name], field.type), field.name) for field in t]
+            # Assign result columns by position
+            else:
+                arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
+                              for i, field in enumerate(t)]
+
+            struct_arrs, struct_names = zip(*arrs_names)
+
+            # TODO: from_arrays args switched for v0.9.0; remove when the minimum pyarrow version is bumped
+            if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
+                arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
+            else:
+                arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
+        else:
+            arrs.append(create_array(s, t))
+
     return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
 
 
@@ -304,10 +333,11 @@ class ArrowStreamPandasSerializer(Serializer):
     Serializes Pandas.Series as Arrow data with Arrow streaming format.
     """
 
-    def __init__(self, timezone, safecheck):
+    def __init__(self, timezone, safecheck, assign_cols_by_name):
         super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
         self._safecheck = safecheck
+        self._assign_cols_by_name = assign_cols_by_name
 
     def arrow_to_pandas(self, arrow_column):
         from pyspark.sql.types import from_arrow_type, \
@@ -326,7 +356,8 @@ class ArrowStreamPandasSerializer(Serializer):
         writer = None
         try:
             for series in iterator:
-                batch = _create_batch(series, self._timezone, self._safecheck)
+                batch = _create_batch(series, self._timezone, self._safecheck,
+                                      self._assign_cols_by_name)
                 if writer is None:
                     write_int(SpecialLengths.START_ARROW_STREAM, stream)
                     writer = pa.RecordBatchStreamWriter(stream, batch.schema)
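
The new branch in `_create_batch` assigns the result pandas.DataFrame's columns to struct fields either by name (when the columns carry string labels and assign_cols_by_name is enabled) or by position. A standalone sketch of the same two rules, assuming pyarrow >= 0.9.0 (illustrative only, not part of the patch):

```python
import pandas as pd
import pyarrow as pa

fields = [pa.field("first", pa.string()), pa.field("last", pa.string())]

# String-labeled columns are matched to struct fields by name
pdf = pd.DataFrame({"last": ["Doe"], "first": ["John"]})
by_name = pa.StructArray.from_arrays(
    [pa.Array.from_pandas(pdf[f.name], type=f.type) for f in fields],
    [f.name for f in fields])

# Integer-labeled columns (e.g. from str.split(expand=True)) fall back
# to positional matching
pdf2 = pd.DataFrame({0: ["John"], 1: ["Doe"]})
by_pos = pa.StructArray.from_arrays(
    [pa.Array.from_pandas(pdf2[pdf2.columns[i]], type=f.type)
     for i, f in enumerate(fields)],
    [f.name for f in fields])
```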
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 3c33e2b..a36423e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2842,8 +2842,9 @@ def pandas_udf(f=None, returnType=None, functionType=None):
 
        A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
       The length of the returned `pandas.Series` must be the same as the input `pandas.Series`.
+       If the return type is :class:`StructType`, the returned value should be a `pandas.DataFrame`.
 
-       :class:`MapType`, :class:`StructType` are currently not supported as output types.
+       :class:`MapType`, nested :class:`StructType` are currently not supported as output types.
 
        Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
        :meth:`pyspark.sql.DataFrame.select`.
@@ -2868,6 +2869,15 @@ def pandas_udf(f=None, returnType=None, functionType=None):
        +----------+--------------+------------+
        |         8|      JOHN DOE|          22|
        +----------+--------------+------------+
+       >>> @pandas_udf("first string, last string")  # doctest: +SKIP
+       ... def split_expand(n):
+       ...     return n.str.split(expand=True)
+       >>> df.select(split_expand("name")).show()  # doctest: +SKIP
+       +------------------+
+       |split_expand(name)|
+       +------------------+
+       |       [John, Doe]|
+       +------------------+
 
        .. note:: The length of `pandas.Series` within a scalar UDF is not that of the whole input
            column, but is the length of an internal batch used for each call to the function.
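
The "internal batch" mentioned in the note above is sized by the existing Arrow batching conf; for example (a usage sketch, not part of this patch):

```python
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")
```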
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index bdf1701..32a2c8a 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -557,8 +557,9 @@ class SparkSession(object):
 
         # Create Arrow record batches
         safecheck = self._wrapped._conf.arrowSafeTypeConversion()
+        col_by_name = True  # col by name only applies to StructType columns; can't happen here
         batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
-                                 timezone, safecheck)
+                                 timezone, safecheck, col_by_name)
                    for pdf_slice in pdf_slices]
 
         # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
index a0a2535..f7684d3 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -273,6 +273,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
             StructField('map', MapType(StringType(), IntegerType())),
             StructField('arr_ts', ArrayType(TimestampType())),
             StructField('null', NullType()),
+            StructField('struct', StructType([StructField('l', LongType())])),
         ]
 
         # TODO: Remove this if-statement once minimum pyarrow version is 0.10.0
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 28ef98d..28b6db2 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -23,13 +23,16 @@ import tempfile
 import time
 import unittest
 
+if sys.version >= '3':
+    unicode = str
+
 from datetime import date, datetime
 from decimal import Decimal
 from distutils.version import LooseVersion
 
 from pyspark.rdd import PythonEvalType
 from pyspark.sql import Column
-from pyspark.sql.functions import array, col, expr, lit, sum, udf, pandas_udf
+from pyspark.sql.functions import array, col, expr, lit, sum, struct, udf, pandas_udf
 from pyspark.sql.types import Row
 from pyspark.sql.types import *
 from pyspark.sql.utils import AnalysisException
@@ -265,6 +268,64 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         result = df.select(array_f(col('array')))
         self.assertEquals(df.collect(), result.collect())
 
+    def test_vectorized_udf_struct_type(self):
+        import pandas as pd
+
+        df = self.spark.range(10)
+        return_type = StructType([
+            StructField('id', LongType()),
+            StructField('str', StringType())])
+
+        def func(id):
+            return pd.DataFrame({'id': id, 'str': id.apply(unicode)})
+
+        f = pandas_udf(func, returnType=return_type)
+
+        expected = df.select(struct(col('id'), col('id').cast('string').alias('str'))
+                             .alias('struct')).collect()
+
+        actual = df.select(f(col('id')).alias('struct')).collect()
+        self.assertEqual(expected, actual)
+
+        g = pandas_udf(func, 'id: long, str: string')
+        actual = df.select(g(col('id')).alias('struct')).collect()
+        self.assertEqual(expected, actual)
+
+    def test_vectorized_udf_struct_complex(self):
+        import pandas as pd
+
+        df = self.spark.range(10)
+        return_type = StructType([
+            StructField('ts', TimestampType()),
+            StructField('arr', ArrayType(LongType()))])
+
+        @pandas_udf(returnType=return_type)
+        def f(id):
+            return pd.DataFrame({'ts': id.apply(lambda i: pd.Timestamp(i)),
+                                 'arr': id.apply(lambda i: [i, i + 1])})
+
+        actual = df.withColumn('f', f(col('id'))).collect()
+        for i, row in enumerate(actual):
+            id, f = row
+            self.assertEqual(i, id)
+            self.assertEqual(pd.Timestamp(i).to_pydatetime(), f[0])
+            self.assertListEqual([i, i + 1], f[1])
+
+    def test_vectorized_udf_nested_struct(self):
+        nested_type = StructType([
+            StructField('id', IntegerType()),
+            StructField('nested', StructType([
+                StructField('foo', StringType()),
+                StructField('bar', FloatType())
+            ]))
+        ])
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    Exception,
+                    'Invalid returnType with scalar Pandas UDFs'):
+                pandas_udf(lambda x: x, returnType=nested_type)
+
     def test_vectorized_udf_complex(self):
         df = self.spark.range(10).select(
             col('id').cast('int').alias('a'),
@@ -331,6 +392,20 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         res = df.select(f(col('id')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_struct_with_empty_partition(self):
+        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))\
+            .withColumn('name', lit('John Doe'))
+
+        @pandas_udf("first string, last string")
+        def split_expand(n):
+            return n.str.split(expand=True)
+
+        result = df.select(split_expand('name')).collect()
+        self.assertEqual(1, len(result))
+        row = result[0]
+        self.assertEqual('John', row[0]['first'])
+        self.assertEqual('Doe', row[0]['last'])
+
     def test_vectorized_udf_varargs(self):
         df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
         f = pandas_udf(lambda *v: v[0], LongType())
@@ -343,6 +418,10 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
                     NotImplementedError,
                     'Invalid returnType.*scalar Pandas UDF.*MapType'):
                 pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
+            with self.assertRaisesRegexp(
+                    NotImplementedError,
+                    'Invalid returnType.*scalar Pandas UDF.*ArrayType.StructType'):
+                pandas_udf(lambda x: x, ArrayType(StructType([StructField('a', IntegerType())])))
 
     def test_vectorized_udf_dates(self):
         schema = StructType().add("idx", LongType()).add("date", DateType())
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 348cb5b1..d87f0f9 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1613,9 +1613,15 @@ def to_arrow_type(dt):
         # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
         arrow_type = pa.timestamp('us', tz='UTC')
     elif type(dt) == ArrayType:
-        if type(dt.elementType) == TimestampType:
+        if type(dt.elementType) in [StructType, TimestampType]:
             raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
         arrow_type = pa.list_(to_arrow_type(dt.elementType))
+    elif type(dt) == StructType:
+        if any(type(field.dataType) == StructType for field in dt):
+            raise TypeError("Nested StructType not supported in conversion to Arrow")
+        fields = [pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
+                  for field in dt]
+        arrow_type = pa.struct(fields)
     else:
         raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
     return arrow_type
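
With the `to_arrow_type` change above, a flat StructType maps to an Arrow struct field-by-field, while nesting fails fast. A quick sketch of the expected behavior (illustrative; assumes a build containing this patch):

```python
import pyarrow as pa
from pyspark.sql.types import (StructType, StructField, LongType,
                               StringType, to_arrow_type)

dt = StructType([StructField("id", LongType()),
                 StructField("str", StringType())])
assert to_arrow_type(dt) == pa.struct(
    [pa.field("id", pa.int64()), pa.field("str", pa.string())])

try:
    to_arrow_type(StructType([StructField("nested", dt)]))
except TypeError as e:
    print(e)  # Nested StructType not supported in conversion to Arrow
```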
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 58f4e0d..275abe9 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -123,7 +123,7 @@ class UserDefinedFunction(object):
         elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
             if isinstance(self._returnType_placeholder, StructType):
                 try:
-                    to_arrow_schema(self._returnType_placeholder)
+                    to_arrow_type(self._returnType_placeholder)
                 except TypeError:
                     raise NotImplementedError(
                         "Invalid returnType with grouped map Pandas UDFs: "
@@ -133,6 +133,9 @@ class UserDefinedFunction(object):
                                 "UDFs: returnType must be a StructType.")
         elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
             try:
+                # StructType is not yet allowed as a return type; check explicitly here to fail fast
+                if isinstance(self._returnType_placeholder, StructType):
+                    raise TypeError
                 to_arrow_type(self._returnType_placeholder)
             except TypeError:
                 raise NotImplementedError(
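
The explicit check above preserves fail-fast behavior for grouped aggregate UDFs now that `to_arrow_type` accepts flat structs: a StructType return is rejected at definition time. A sketch of the expected error (illustrative only):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

try:
    pandas_udf(lambda v: v.mean(), "a long, b long", PandasUDFType.GROUPED_AGG)
except NotImplementedError as e:
    print(e)  # Invalid returnType with grouped aggregate Pandas UDFs: ...
```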
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 01934a0..0e9b6d6 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -39,7 +39,7 @@ from pyspark.rdd import PythonEvalType
 from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
-from pyspark.sql.types import to_arrow_type
+from pyspark.sql.types import to_arrow_type, StructType
 from pyspark.util import _get_argspec, fail_on_stopiteration
 from pyspark import shuffle
 
@@ -90,8 +90,9 @@ def wrap_scalar_pandas_udf(f, return_type):
     def verify_result_length(*a):
         result = f(*a)
         if not hasattr(result, "__len__"):
+            pd_type = "Pandas.DataFrame" if type(return_type) == StructType else "Pandas.Series"
             raise TypeError("Return type of the user-defined function should be "
-                            "Pandas.Series, but is {}".format(type(result)))
+                            "{}, but is {}".format(pd_type, type(result)))
         if len(result) != len(a[0]):
             raise RuntimeError("Result vector from pandas_udf was not the required length: "
                                "expected %d, got %d" % (len(a[0]), len(result)))
@@ -254,7 +255,12 @@ def read_udfs(pickleSer, infile, eval_type):
         timezone = runner_conf.get("spark.sql.session.timeZone", None)
         safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
                                     "false").lower() == 'true'
-        ser = ArrowStreamPandasSerializer(timezone, safecheck)
+        # NOTE: this is duplicated from wrap_grouped_map_pandas_udf
+        assign_cols_by_name = runner_conf.get(
+            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
+            .lower() == "true"
+
+        ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
     else:
         ser = BatchedSerializer(PickleSerializer(), 100)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org