Posted to commits@spark.apache.org by gu...@apache.org on 2017/09/26 01:54:06 UTC
spark git commit: [SPARK-22106][PYSPARK][SQL] Disable 0-parameter pandas_udf and add doctests
Repository: spark
Updated Branches:
refs/heads/master ce204780e -> d8e825e3b
[SPARK-22106][PYSPARK][SQL] Disable 0-parameter pandas_udf and add doctests
## What changes were proposed in this pull request?
This change disables 0-parameter pandas_udfs because the API for them was overly complex and awkward, and the same behavior is easily achieved by passing an index column as an input argument. It also adds doctests for pandas_udfs, which revealed bugs in handling empty partitions and in using the pandas_udf decorator.
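For reference, a minimal sketch of that workaround (the `one` name and the use of `range`'s `id` column are illustrative, not part of this patch): take an existing column as a dummy input and size the result from the input vector length.

    >>> import pandas as pd
    >>> from pyspark.sql.functions import pandas_udf
    >>> from pyspark.sql.types import LongType
    >>> df = spark.range(10)
    >>> # pass the index-like 'id' column instead of defining a 0-parameter udf
    >>> one = pandas_udf(lambda v: pd.Series([1] * len(v)), LongType())
    >>> df.select(one(df['id'])).show()  # doctest: +SKIP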
## How was this patch tested?
Reworked the existing 0-parameter test to verify that an error is raised, added a doctest for pandas_udf, and added new tests for empty partitions and decorator usage.
Author: Bryan Cutler <cu...@gmail.com>
Closes #19325 from BryanCutler/arrow-pandas_udf-0-param-remove-SPARK-22106.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d8e825e3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d8e825e3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d8e825e3
Branch: refs/heads/master
Commit: d8e825e3bc5fdb8ba00eba431512fa7f771417f1
Parents: ce20478
Author: Bryan Cutler <cu...@gmail.com>
Authored: Tue Sep 26 10:54:00 2017 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Tue Sep 26 10:54:00 2017 +0900
----------------------------------------------------------------------
python/pyspark/serializers.py | 15 +----
python/pyspark/sql/functions.py | 32 ++++++++---
python/pyspark/sql/tests.py | 59 +++++++++++++++-----
python/pyspark/worker.py | 25 ++++-----
.../execution/python/ArrowEvalPythonExec.scala | 10 ++--
5 files changed, 89 insertions(+), 52 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/d8e825e3/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 887c702..7c1fbad 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -216,9 +216,6 @@ class ArrowPandasSerializer(ArrowSerializer):
Serializes Pandas.Series as Arrow data.
"""
- def __init__(self):
- super(ArrowPandasSerializer, self).__init__()
-
def dumps(self, series):
"""
Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or
@@ -245,16 +242,10 @@ class ArrowPandasSerializer(ArrowSerializer):
def loads(self, obj):
"""
- Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series
- followed by a dictionary containing length of the loaded batches.
+ Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series.
"""
- import pyarrow as pa
- reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
- batches = [reader.get_batch(i) for i in xrange(reader.num_record_batches)]
- # NOTE: a 0-parameter pandas_udf will produce an empty batch that can have num_rows set
- num_rows = sum((batch.num_rows for batch in batches))
- table = pa.Table.from_batches(batches)
- return [c.to_pandas() for c in table.itercolumns()] + [{"length": num_rows}]
+ table = super(ArrowPandasSerializer, self).loads(obj)
+ return [c.to_pandas() for c in table.itercolumns()]
def __repr__(self):
return "ArrowPandasSerializer"
http://git-wip-us.apache.org/repos/asf/spark/blob/d8e825e3/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 46e3a85..63e9a83 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2127,6 +2127,10 @@ class UserDefinedFunction(object):
def _create_udf(f, returnType, vectorized):
def _udf(f, returnType=StringType(), vectorized=vectorized):
+ if vectorized:
+ import inspect
+ if len(inspect.getargspec(f).args) == 0:
+ raise NotImplementedError("0-parameter pandas_udfs are not currently supported")
udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
return udf_obj._wrapped()
@@ -2183,14 +2187,28 @@ def pandas_udf(f=None, returnType=StringType()):
:param f: python function if used as a standalone function
:param returnType: a :class:`pyspark.sql.types.DataType` object
- # TODO: doctest
+ >>> from pyspark.sql.types import IntegerType, StringType
+ >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
+ >>> @pandas_udf(returnType=StringType())
+ ... def to_upper(s):
+ ... return s.str.upper()
+ ...
+ >>> @pandas_udf(returnType="integer")
+ ... def add_one(x):
+ ... return x + 1
+ ...
+ >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+ >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
+ ... .show() # doctest: +SKIP
+ +----------+--------------+------------+
+ |slen(name)|to_upper(name)|add_one(age)|
+ +----------+--------------+------------+
+ | 8| JOHN DOE| 22|
+ +----------+--------------+------------+
"""
- import inspect
- # If function "f" does not define the optional kwargs, then wrap with a kwargs placeholder
- if inspect.getargspec(f).keywords is None:
- return _create_udf(lambda *a, **kwargs: f(*a), returnType=returnType, vectorized=True)
- else:
- return _create_udf(f, returnType=returnType, vectorized=True)
+ wrapped_udf = _create_udf(f, returnType=returnType, vectorized=True)
+
+ return wrapped_udf
blacklist = ['map', 'since', 'ignore_unicode_prefix']
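With this change a 0-parameter function is rejected eagerly when the udf is defined, instead of failing later on the executors. A minimal demonstration (the error text is taken from the check in _create_udf above):

    >>> from pyspark.sql.functions import pandas_udf
    >>> from pyspark.sql.types import LongType
    >>> pandas_udf(lambda: 1, LongType())  # doctest: +SKIP
    Traceback (most recent call last):
        ...
    NotImplementedError: 0-parameter pandas_udfs are not currently supported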
http://git-wip-us.apache.org/repos/asf/spark/blob/d8e825e3/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 3db8bee..1b3af42 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3256,11 +3256,20 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
def test_vectorized_udf_zero_parameter(self):
from pyspark.sql.functions import pandas_udf
- import pandas as pd
- df = self.spark.range(10)
- f0 = pandas_udf(lambda **kwargs: pd.Series(1).repeat(kwargs['length']), LongType())
- res = df.select(f0())
- self.assertEquals(df.select(lit(1)).collect(), res.collect())
+ error_str = '0-parameter pandas_udfs.*not.*supported'
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(NotImplementedError, error_str):
+ pandas_udf(lambda: 1, LongType())
+
+ with self.assertRaisesRegexp(NotImplementedError, error_str):
+ @pandas_udf
+ def zero_no_type():
+ return 1
+
+ with self.assertRaisesRegexp(NotImplementedError, error_str):
+ @pandas_udf(LongType())
+ def zero_with_type():
+ return 1
def test_vectorized_udf_datatype_string(self):
from pyspark.sql.functions import pandas_udf, col
@@ -3308,12 +3317,12 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
from pyspark.sql.functions import pandas_udf, col
import pandas as pd
df = self.spark.range(10)
- raise_exception = pandas_udf(lambda: pd.Series(1), LongType())
+ raise_exception = pandas_udf(lambda _: pd.Series(1), LongType())
with QuietTest(self.sc):
with self.assertRaisesRegexp(
Exception,
'Result vector from pandas_udf was not the required length'):
- df.select(raise_exception()).collect()
+ df.select(raise_exception(col('id'))).collect()
def test_vectorized_udf_mix_udf(self):
from pyspark.sql.functions import pandas_udf, udf, col
@@ -3328,22 +3337,44 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
def test_vectorized_udf_chained(self):
from pyspark.sql.functions import pandas_udf, col
- df = self.spark.range(10).toDF('x')
+ df = self.spark.range(10)
f = pandas_udf(lambda x: x + 1, LongType())
g = pandas_udf(lambda x: x - 1, LongType())
- res = df.select(g(f(col('x'))))
+ res = df.select(g(f(col('id'))))
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_wrong_return_type(self):
from pyspark.sql.functions import pandas_udf, col
- df = self.spark.range(10).toDF('x')
+ df = self.spark.range(10)
f = pandas_udf(lambda x: x * 1.0, StringType())
with QuietTest(self.sc):
- with self.assertRaisesRegexp(
- Exception,
- 'Invalid.*type.*string'):
- df.select(f(col('x'))).collect()
+ with self.assertRaisesRegexp(Exception, 'Invalid.*type.*string'):
+ df.select(f(col('id'))).collect()
+
+ def test_vectorized_udf_return_scalar(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.range(10)
+ f = pandas_udf(lambda x: 1.0, DoubleType())
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(Exception, 'Return.*type.*pandas_udf.*Series'):
+ df.select(f(col('id'))).collect()
+
+ def test_vectorized_udf_decorator(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.range(10)
+ @pandas_udf(returnType=LongType())
+ def identity(x):
+ return x
+ res = df.select(identity(col('id')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_empty_partition(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
+ f = pandas_udf(lambda x: x, LongType())
+ res = df.select(f(col('id')))
+ self.assertEquals(df.collect(), res.collect())
if __name__ == "__main__":
from pyspark.sql.tests import *
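For context on the new empty-partition test above: parallelizing a single Row across two partitions leaves one partition empty, a case the new doctests revealed as a failure before this patch. A minimal reproduction sketch (assumes a running SparkSession bound to `spark`):

    >>> from pyspark.sql import Row
    >>> from pyspark.sql.functions import pandas_udf, col
    >>> from pyspark.sql.types import LongType
    >>> df = spark.createDataFrame(spark.sparkContext.parallelize([Row(id=1)], 2))
    >>> identity = pandas_udf(lambda x: x, LongType())
    >>> df.select(identity(col('id'))).collect()  # doctest: +SKIP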
http://git-wip-us.apache.org/repos/asf/spark/blob/d8e825e3/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 0e35cf7..fd917c4 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -60,12 +60,9 @@ def read_command(serializer, file):
return command
-def chain(f, g, eval_type):
+def chain(f, g):
"""chain two functions together """
- if eval_type == PythonEvalType.SQL_PANDAS_UDF:
- return lambda *a, **kwargs: g(f(*a, **kwargs), **kwargs)
- else:
- return lambda *a: g(f(*a))
+ return lambda *a: g(f(*a))
def wrap_udf(f, return_type):
@@ -80,14 +77,14 @@ def wrap_pandas_udf(f, return_type):
arrow_return_type = toArrowType(return_type)
def verify_result_length(*a):
- kwargs = a[-1]
- result = f(*a[:-1], **kwargs)
- if len(result) != kwargs["length"]:
+ result = f(*a)
+ if not hasattr(result, "__len__"):
+ raise TypeError("Return type of pandas_udf should be a Pandas.Series")
+ if len(result) != len(a[0]):
raise RuntimeError("Result vector from pandas_udf was not the required length: "
- "expected %d, got %d\nUse input vector length or kwargs['length']"
- % (kwargs["length"], len(result)))
- return result, arrow_return_type
- return lambda *a: verify_result_length(*a)
+ "expected %d, got %d" % (len(a[0]), len(result)))
+ return result
+ return lambda *a: (verify_result_length(*a), arrow_return_type)
def read_single_udf(pickleSer, infile, eval_type):
@@ -99,11 +96,9 @@ def read_single_udf(pickleSer, infile, eval_type):
if row_func is None:
row_func = f
else:
- row_func = chain(row_func, f, eval_type)
+ row_func = chain(row_func, f)
# the last returnType will be the return type of UDF
if eval_type == PythonEvalType.SQL_PANDAS_UDF:
- # A pandas_udf will take kwargs as the last argument
- arg_offsets = arg_offsets + [-1]
return arg_offsets, wrap_pandas_udf(row_func, return_type)
else:
return arg_offsets, wrap_udf(row_func, return_type)
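For context, a standalone sketch of the simplified chaining plus the checks that wrap_pandas_udf now performs (plain Python, independent of the worker plumbing; the example functions are illustrative):

    import pandas as pd

    def chain(f, g):
        # compose two functions over the same positional arguments
        return lambda *a: g(f(*a))

    add_one = lambda s: s + 1
    times_two = lambda s: s * 2
    composed = chain(add_one, times_two)

    s = pd.Series([1, 2, 3])
    result = composed(s)                 # (s + 1) * 2 -> [4, 6, 8]
    assert hasattr(result, "__len__")    # pandas_udf must return a Pandas.Series
    assert len(result) == len(s)         # result must match the input length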
http://git-wip-us.apache.org/repos/asf/spark/blob/d8e825e3/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index f8bdc1e..5e72cd2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -51,10 +51,12 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
outputIterator.map(new ArrowPayload(_)), context)
// Verify that the output schema is correct
- val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
- .map { case (attr, i) => attr.withName(s"_$i") })
- assert(schemaOut.equals(outputRowIterator.schema),
- s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")
+ if (outputRowIterator.hasNext) {
+ val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
+ .map { case (attr, i) => attr.withName(s"_$i") })
+ assert(schemaOut.equals(outputRowIterator.schema),
+ s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")
+ }
outputRowIterator
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org