You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by cu...@apache.org on 2019/03/25 18:26:41 UTC

[spark] branch master updated: [SPARK-27240][PYTHON] Use pandas DataFrame for struct type argument in Scalar Pandas UDF.

This is an automated email from the ASF dual-hosted git repository.

cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 594be7a  [SPARK-27240][PYTHON] Use pandas DataFrame for struct type argument in Scalar Pandas UDF.
594be7a is described below

commit 594be7a911584a05f875f217ac39af9a5eeff049
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Mon Mar 25 11:26:09 2019 -0700

    [SPARK-27240][PYTHON] Use pandas DataFrame for struct type argument in Scalar Pandas UDF.
    
    ## What changes were proposed in this pull request?
    
    We now support returning a pandas DataFrame for struct type in Scalar Pandas UDF.
    
    If we chain another Pandas UDF after a Scalar Pandas UDF that returns a pandas DataFrame, the argument of the chained UDF will be a pandas DataFrame, but currently we don't support a pandas DataFrame as an argument of a Scalar Pandas UDF. That means there is an inconsistency between the chained UDF and the single UDF.
    
    We should support taking a pandas DataFrame for a struct type argument in Scalar Pandas UDF to be consistent.
    Currently pyarrow >=0.11 is supported.
    
    ## How was this patch tested?
    
    Modified and added some tests.
    
    Closes #24177 from ueshin/issues/SPARK-27240/structtype_argument.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Bryan Cutler <cu...@gmail.com>
---
 python/pyspark/serializers.py                      | 29 +++++++++++++++----
 python/pyspark/sql/tests/test_pandas_udf_scalar.py | 33 ++++++++++++++++++++++
 python/pyspark/sql/types.py                        | 10 +++++++
 python/pyspark/worker.py                           |  6 +++-
 4 files changed, 72 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 58f7552..ed419db 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -260,11 +260,10 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         self._safecheck = safecheck
         self._assign_cols_by_name = assign_cols_by_name
 
-    def arrow_to_pandas(self, arrow_column):
-        from pyspark.sql.types import from_arrow_type, \
-            _arrow_column_to_pandas, _check_series_localize_timestamps
+    def arrow_to_pandas(self, arrow_column, data_type):
+        from pyspark.sql.types import _arrow_column_to_pandas, _check_series_localize_timestamps
 
-        s = _arrow_column_to_pandas(arrow_column, from_arrow_type(arrow_column.type))
+        s = _arrow_column_to_pandas(arrow_column, data_type)
         s = _check_series_localize_timestamps(s, self._timezone)
         return s
 
@@ -366,8 +365,10 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
         """
         batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
         import pyarrow as pa
+        from pyspark.sql.types import from_arrow_type
         for batch in batches:
-            yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
+            yield [self.arrow_to_pandas(c, from_arrow_type(c.type))
+                   for c in pa.Table.from_batches([batch]).itercolumns()]
 
     def __repr__(self):
         return "ArrowStreamPandasSerializer"
@@ -378,6 +379,24 @@ class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
     Serializer used by Python worker to evaluate Pandas UDFs
     """
 
+    def __init__(self, timezone, safecheck, assign_cols_by_name, df_for_struct=False):
+        super(ArrowStreamPandasUDFSerializer, self) \
+            .__init__(timezone, safecheck, assign_cols_by_name)
+        self._df_for_struct = df_for_struct
+
+    def arrow_to_pandas(self, arrow_column, data_type):
+        from pyspark.sql.types import StructType, \
+            _arrow_column_to_pandas, _check_dataframe_localize_timestamps
+
+        if self._df_for_struct and type(data_type) == StructType:
+            import pandas as pd
+            series = [_arrow_column_to_pandas(column, field.dataType).rename(field.name)
+                      for column, field in zip(arrow_column.flatten(), data_type)]
+            s = _check_dataframe_localize_timestamps(pd.concat(series, axis=1), self._timezone)
+        else:
+            s = super(ArrowStreamPandasUDFSerializer, self).arrow_to_pandas(arrow_column, data_type)
+        return s
+
     def dump_stream(self, iterator, stream):
         """
         Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 28b6db2..7df918b 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -270,6 +270,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
 
     def test_vectorized_udf_struct_type(self):
         import pandas as pd
+        import pyarrow as pa
 
         df = self.spark.range(10)
         return_type = StructType([
@@ -291,6 +292,18 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         actual = df.select(g(col('id')).alias('struct')).collect()
         self.assertEqual(expected, actual)
 
+        struct_f = pandas_udf(lambda x: x, return_type)
+        actual = df.select(struct_f(struct(col('id'), col('id').cast('string').alias('str'))))
+        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            with QuietTest(self.sc):
+                from py4j.protocol import Py4JJavaError
+                with self.assertRaisesRegexp(
+                        Py4JJavaError,
+                        'Unsupported type in conversion from Arrow'):
+                    self.assertEqual(expected, actual.collect())
+        else:
+            self.assertEqual(expected, actual.collect())
+
     def test_vectorized_udf_struct_complex(self):
         import pandas as pd
 
@@ -363,6 +376,26 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         res = df.select(g(f(col('id'))))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_chained_struct_type(self):
+        import pandas as pd
+
+        df = self.spark.range(10)
+        return_type = StructType([
+            StructField('id', LongType()),
+            StructField('str', StringType())])
+
+        @pandas_udf(return_type)
+        def f(id):
+            return pd.DataFrame({'id': id, 'str': id.apply(unicode)})
+
+        g = pandas_udf(lambda x: x, return_type)
+
+        expected = df.select(struct(col('id'), col('id').cast('string').alias('str'))
+                             .alias('struct')).collect()
+
+        actual = df.select(g(f(col('id'))).alias('struct')).collect()
+        self.assertEqual(expected, actual)
+
     def test_vectorized_udf_wrong_return_type(self):
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index d87f0f9..21595e5 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1674,6 +1674,16 @@ def from_arrow_type(at):
         if types.is_timestamp(at.value_type):
             raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
         spark_type = ArrayType(from_arrow_type(at.value_type))
+    elif types.is_struct(at):
+        # TODO: remove version check once minimum pyarrow version is 0.10.0
+        if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            raise TypeError("Unsupported type in conversion from Arrow: " + str(at) +
+                            "\nPlease install pyarrow >= 0.10.0 for StructType support.")
+        if any(types.is_struct(field.type) for field in at):
+            raise TypeError("Nested StructType not supported in conversion from Arrow: " + str(at))
+        return StructType(
+            [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
+             for field in at])
     else:
         raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
     return spark_type
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0e3bef6..16257be 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -253,7 +253,11 @@ def read_udfs(pickleSer, infile, eval_type):
             "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
             .lower() == "true"
 
-        ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name)
+        # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
+        # pandas Series. See SPARK-27240.
+        df_for_struct = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
+        ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name,
+                                             df_for_struct)
     else:
         ser = BatchedSerializer(PickleSerializer(), 100)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org