Posted to commits@spark.apache.org by gu...@apache.org on 2017/09/26 01:54:06 UTC

spark git commit: [SPARK-22106][PYSPARK][SQL] Disable 0-parameter pandas_udf and add doctests

Repository: spark
Updated Branches:
  refs/heads/master ce204780e -> d8e825e3b


[SPARK-22106][PYSPARK][SQL] Disable 0-parameter pandas_udf and add doctests

## What changes were proposed in this pull request?

This change disables the use of 0-parameter pandas_udfs because the API for them is overly complex and awkward, and they can easily be worked around by using an index column as an input argument. It also adds doctests for pandas_udfs, which revealed bugs in handling empty partitions and in using the pandas_udf decorator.
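
For reference, a minimal sketch of that workaround, assuming an active SparkSession named `spark` (the UDF name `ones` is illustrative, not part of this change): a pandas_udf that needs no real input can take any existing column, such as `id`, and use only its length.

    from pyspark.sql.functions import pandas_udf, col
    from pyspark.sql.types import LongType
    import pandas as pd

    # Instead of a 0-parameter pandas_udf, pass a column and use only its
    # length to size the result vector.
    ones = pandas_udf(lambda v: pd.Series(1).repeat(len(v)), LongType())

    df = spark.range(10)
    df.select(ones(col('id'))).show()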

## How was this patch tested?

Reworked the existing 0-parameter test to verify that an error is raised, added a doctest for pandas_udf, and added new tests for empty partitions and decorator usage.

Author: Bryan Cutler <cu...@gmail.com>

Closes #19325 from BryanCutler/arrow-pandas_udf-0-param-remove-SPARK-22106.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d8e825e3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d8e825e3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d8e825e3

Branch: refs/heads/master
Commit: d8e825e3bc5fdb8ba00eba431512fa7f771417f1
Parents: ce20478
Author: Bryan Cutler <cu...@gmail.com>
Authored: Tue Sep 26 10:54:00 2017 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Tue Sep 26 10:54:00 2017 +0900

----------------------------------------------------------------------
 python/pyspark/serializers.py                   | 15 +----
 python/pyspark/sql/functions.py                 | 32 ++++++++---
 python/pyspark/sql/tests.py                     | 59 +++++++++++++++-----
 python/pyspark/worker.py                        | 25 ++++-----
 .../execution/python/ArrowEvalPythonExec.scala  | 10 ++--
 5 files changed, 89 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d8e825e3/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 887c702..7c1fbad 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -216,9 +216,6 @@ class ArrowPandasSerializer(ArrowSerializer):
     Serializes Pandas.Series as Arrow data.
     """
 
-    def __init__(self):
-        super(ArrowPandasSerializer, self).__init__()
-
     def dumps(self, series):
         """
         Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or
@@ -245,16 +242,10 @@ class ArrowPandasSerializer(ArrowSerializer):
 
     def loads(self, obj):
         """
-        Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series
-        followed by a dictionary containing length of the loaded batches.
+        Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series.
         """
-        import pyarrow as pa
-        reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
-        batches = [reader.get_batch(i) for i in xrange(reader.num_record_batches)]
-        # NOTE: a 0-parameter pandas_udf will produce an empty batch that can have num_rows set
-        num_rows = sum((batch.num_rows for batch in batches))
-        table = pa.Table.from_batches(batches)
-        return [c.to_pandas() for c in table.itercolumns()] + [{"length": num_rows}]
+        table = super(ArrowPandasSerializer, self).loads(obj)
+        return [c.to_pandas() for c in table.itercolumns()]
 
     def __repr__(self):
         return "ArrowPandasSerializer"

http://git-wip-us.apache.org/repos/asf/spark/blob/d8e825e3/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 46e3a85..63e9a83 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2127,6 +2127,10 @@ class UserDefinedFunction(object):
 def _create_udf(f, returnType, vectorized):
 
     def _udf(f, returnType=StringType(), vectorized=vectorized):
+        if vectorized:
+            import inspect
+            if len(inspect.getargspec(f).args) == 0:
+                raise NotImplementedError("0-parameter pandas_udfs are not currently supported")
         udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
         return udf_obj._wrapped()
 
@@ -2183,14 +2187,28 @@ def pandas_udf(f=None, returnType=StringType()):
     :param f: python function if used as a standalone function
     :param returnType: a :class:`pyspark.sql.types.DataType` object
 
-    # TODO: doctest
+    >>> from pyspark.sql.types import IntegerType, StringType
+    >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
+    >>> @pandas_udf(returnType=StringType())
+    ... def to_upper(s):
+    ...     return s.str.upper()
+    ...
+    >>> @pandas_udf(returnType="integer")
+    ... def add_one(x):
+    ...     return x + 1
+    ...
+    >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+    >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
+    ...     .show() # doctest: +SKIP
+    +----------+--------------+------------+
+    |slen(name)|to_upper(name)|add_one(age)|
+    +----------+--------------+------------+
+    |         8|      JOHN DOE|          22|
+    +----------+--------------+------------+
     """
-    import inspect
-    # If function "f" does not define the optional kwargs, then wrap with a kwargs placeholder
-    if inspect.getargspec(f).keywords is None:
-        return _create_udf(lambda *a, **kwargs: f(*a), returnType=returnType, vectorized=True)
-    else:
-        return _create_udf(f, returnType=returnType, vectorized=True)
+    wrapped_udf = _create_udf(f, returnType=returnType, vectorized=True)
+
+    return wrapped_udf
 
 
 blacklist = ['map', 'since', 'ignore_unicode_prefix']
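
The new guard in _create_udf relies on the function's argspec to reject 0-parameter functions before a UserDefinedFunction is built. A small sketch of that check in isolation (note that getargspec is deprecated in newer Python 3 releases, where getfullargspec is the equivalent):

    import inspect

    def no_args():
        return 1

    def one_arg(x):
        return x + 1

    # A pandas_udf wrapping no_args would now raise NotImplementedError,
    # while one_arg is accepted.
    assert len(inspect.getargspec(no_args).args) == 0
    assert len(inspect.getargspec(one_arg).args) == 1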

http://git-wip-us.apache.org/repos/asf/spark/blob/d8e825e3/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 3db8bee..1b3af42 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3256,11 +3256,20 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
 
     def test_vectorized_udf_zero_parameter(self):
         from pyspark.sql.functions import pandas_udf
-        import pandas as pd
-        df = self.spark.range(10)
-        f0 = pandas_udf(lambda **kwargs: pd.Series(1).repeat(kwargs['length']), LongType())
-        res = df.select(f0())
-        self.assertEquals(df.select(lit(1)).collect(), res.collect())
+        error_str = '0-parameter pandas_udfs.*not.*supported'
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(NotImplementedError, error_str):
+                pandas_udf(lambda: 1, LongType())
+
+            with self.assertRaisesRegexp(NotImplementedError, error_str):
+                @pandas_udf
+                def zero_no_type():
+                    return 1
+
+            with self.assertRaisesRegexp(NotImplementedError, error_str):
+                @pandas_udf(LongType())
+                def zero_with_type():
+                    return 1
 
     def test_vectorized_udf_datatype_string(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -3308,12 +3317,12 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
         from pyspark.sql.functions import pandas_udf, col
         import pandas as pd
         df = self.spark.range(10)
-        raise_exception = pandas_udf(lambda: pd.Series(1), LongType())
+        raise_exception = pandas_udf(lambda _: pd.Series(1), LongType())
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(
                     Exception,
                     'Result vector from pandas_udf was not the required length'):
-                df.select(raise_exception()).collect()
+                df.select(raise_exception(col('id'))).collect()
 
     def test_vectorized_udf_mix_udf(self):
         from pyspark.sql.functions import pandas_udf, udf, col
@@ -3328,22 +3337,44 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
 
     def test_vectorized_udf_chained(self):
         from pyspark.sql.functions import pandas_udf, col
-        df = self.spark.range(10).toDF('x')
+        df = self.spark.range(10)
         f = pandas_udf(lambda x: x + 1, LongType())
         g = pandas_udf(lambda x: x - 1, LongType())
-        res = df.select(g(f(col('x'))))
+        res = df.select(g(f(col('id'))))
         self.assertEquals(df.collect(), res.collect())
 
     def test_vectorized_udf_wrong_return_type(self):
         from pyspark.sql.functions import pandas_udf, col
-        df = self.spark.range(10).toDF('x')
+        df = self.spark.range(10)
         f = pandas_udf(lambda x: x * 1.0, StringType())
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    Exception,
-                    'Invalid.*type.*string'):
-                df.select(f(col('x'))).collect()
+            with self.assertRaisesRegexp(Exception, 'Invalid.*type.*string'):
+                df.select(f(col('id'))).collect()
+
+    def test_vectorized_udf_return_scalar(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.range(10)
+        f = pandas_udf(lambda x: 1.0, DoubleType())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, 'Return.*type.*pandas_udf.*Series'):
+                df.select(f(col('id'))).collect()
+
+    def test_vectorized_udf_decorator(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.range(10)
 
+        @pandas_udf(returnType=LongType())
+        def identity(x):
+            return x
+        res = df.select(identity(col('id')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_empty_partition(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
+        f = pandas_udf(lambda x: x, LongType())
+        res = df.select(f(col('id')))
+        self.assertEquals(df.collect(), res.collect())
 
 if __name__ == "__main__":
     from pyspark.sql.tests import *
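
The empty-partition test above relies on parallelize creating more slices than rows, which leaves one partition empty. A short sketch of that setup, assuming an active SparkContext `sc` and SparkSession `spark`:

    from pyspark.sql import Row

    # Two slices for a single row: one partition holds the row, the other
    # is empty, which is the case the new test exercises.
    rdd = sc.parallelize([Row(id=1)], 2)
    df = spark.createDataFrame(rdd)
    print(rdd.getNumPartitions())   # 2
    print(rdd.glom().collect())     # one of the two partitions is empty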

http://git-wip-us.apache.org/repos/asf/spark/blob/d8e825e3/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0e35cf7..fd917c4 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -60,12 +60,9 @@ def read_command(serializer, file):
     return command
 
 
-def chain(f, g, eval_type):
+def chain(f, g):
     """chain two functions together """
-    if eval_type == PythonEvalType.SQL_PANDAS_UDF:
-        return lambda *a, **kwargs: g(f(*a, **kwargs), **kwargs)
-    else:
-        return lambda *a: g(f(*a))
+    return lambda *a: g(f(*a))
 
 
 def wrap_udf(f, return_type):
@@ -80,14 +77,14 @@ def wrap_pandas_udf(f, return_type):
     arrow_return_type = toArrowType(return_type)
 
     def verify_result_length(*a):
-        kwargs = a[-1]
-        result = f(*a[:-1], **kwargs)
-        if len(result) != kwargs["length"]:
+        result = f(*a)
+        if not hasattr(result, "__len__"):
+            raise TypeError("Return type of pandas_udf should be a Pandas.Series")
+        if len(result) != len(a[0]):
             raise RuntimeError("Result vector from pandas_udf was not the required length: "
-                               "expected %d, got %d\nUse input vector length or kwargs['length']"
-                               % (kwargs["length"], len(result)))
-        return result, arrow_return_type
-    return lambda *a: verify_result_length(*a)
+                               "expected %d, got %d" % (len(a[0]), len(result)))
+        return result
+    return lambda *a: (verify_result_length(*a), arrow_return_type)
 
 
 def read_single_udf(pickleSer, infile, eval_type):
@@ -99,11 +96,9 @@ def read_single_udf(pickleSer, infile, eval_type):
         if row_func is None:
             row_func = f
         else:
-            row_func = chain(row_func, f, eval_type)
+            row_func = chain(row_func, f)
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_PANDAS_UDF:
-        # A pandas_udf will take kwargs as the last argument
-        arg_offsets = arg_offsets + [-1]
         return arg_offsets, wrap_pandas_udf(row_func, return_type)
     else:
         return arg_offsets, wrap_udf(row_func, return_type)
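
From the caller's side, the reworked wrap_pandas_udf now validates the result against the length of the first input series instead of a kwargs['length'] hint. A self-contained sketch of that behaviour (the helper name verify_length is made up for illustration):

    import pandas as pd

    def verify_length(f):
        # Same checks as wrap_pandas_udf above: the result must be sized
        # like the first input series.
        def wrapped(*a):
            result = f(*a)
            if not hasattr(result, "__len__"):
                raise TypeError("Return type of pandas_udf should be a Pandas.Series")
            if len(result) != len(a[0]):
                raise RuntimeError("Result vector from pandas_udf was not the required "
                                   "length: expected %d, got %d" % (len(a[0]), len(result)))
            return result
        return wrapped

    ok = verify_length(lambda s: s + 1)
    ok(pd.Series([1, 2, 3]))              # passes
    bad = verify_length(lambda s: 1.0)
    # bad(pd.Series([1, 2, 3]))           # would raise TypeError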

http://git-wip-us.apache.org/repos/asf/spark/blob/d8e825e3/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index f8bdc1e..5e72cd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -51,10 +51,12 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
       outputIterator.map(new ArrowPayload(_)), context)
 
     // Verify that the output schema is correct
-    val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
-      .map { case (attr, i) => attr.withName(s"_$i") })
-    assert(schemaOut.equals(outputRowIterator.schema),
-      s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")
+    if (outputRowIterator.hasNext) {
+      val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
+        .map { case (attr, i) => attr.withName(s"_$i") })
+      assert(schemaOut.equals(outputRowIterator.schema),
+        s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")
+    }
 
     outputRowIterator
   }

