Posted to commits@spark.apache.org by we...@apache.org on 2017/09/22 08:20:28 UTC
spark git commit: [SPARK-21190][PYSPARK] Python Vectorized UDFs
Repository: spark
Updated Branches:
refs/heads/master 8f130ad40 -> 27fc536d9
[SPARK-21190][PYSPARK] Python Vectorized UDFs
This PR adds vectorized UDFs to the Python API
**Proposed API**
Introduce a flag to turn on vectorization for a defined UDF, for example:
```
@pandas_udf(DoubleType())
def plus(a, b):
    return a + b
```
or
```
plus = pandas_udf(lambda a, b: a + b, DoubleType())
```
Usage is the same as for normal UDFs.
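For instance, given the `plus` UDF above, usage is a drop-in replacement (a minimal sketch; a DataFrame `df` with double columns `a` and `b` is assumed for illustration):
```
from pyspark.sql.functions import col

res = df.select(plus(col('a'), col('b')))
```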
0-parameter UDFs
pandas_udf functions can declare an optional `**kwargs`; when the UDF is evaluated, it will contain a key "length" that gives the required length of the output. For example:
```
@pandas_udf(LongType())
def f0(**kwargs):
    return pd.Series(1).repeat(kwargs["length"])
df.select(f0())
```
Added new unit tests in pyspark.sql that are enabled if pyarrow and Pandas are available.
- [x] Fix support for promoted types with null values
- [ ] Discuss 0-param UDF API (use of kwargs)
- [x] Add tests for chained UDFs
- [ ] Discuss behavior when pyarrow not installed / enabled
- [ ] Cleanup pydoc and add user docs
Author: Bryan Cutler <cu...@gmail.com>
Author: Takuya UESHIN <ue...@databricks.com>
Closes #18659 from BryanCutler/arrow-vectorized-udfs-SPARK-21404.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/27fc536d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/27fc536d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/27fc536d
Branch: refs/heads/master
Commit: 27fc536d9a54eccef7d1cbbe2a6a008043d62ba4
Parents: 8f130ad
Author: Bryan Cutler <cu...@gmail.com>
Authored: Fri Sep 22 16:17:41 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Sep 22 16:17:50 2017 +0800
----------------------------------------------------------------------
.../org/apache/spark/api/python/PythonRDD.scala | 22 ++-
python/pyspark/serializers.py | 65 +++++-
python/pyspark/sql/functions.py | 49 +++--
python/pyspark/sql/tests.py | 197 +++++++++++++++++++
python/pyspark/sql/types.py | 27 +++
python/pyspark/worker.py | 57 ++++--
.../execution/python/ArrowEvalPythonExec.scala | 61 ++++++
.../execution/python/BatchEvalPythonExec.scala | 193 ++++++------------
.../sql/execution/python/EvalPythonExec.scala | 142 +++++++++++++
.../execution/python/ExtractPythonUDFs.scala | 11 +-
.../spark/sql/execution/python/PythonUDF.scala | 3 +-
.../python/UserDefinedPythonFunction.scala | 5 +-
.../python/BatchEvalPythonExecSuite.scala | 7 +-
13 files changed, 666 insertions(+), 173 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 3377101..86d0405 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -83,10 +83,23 @@ private[spark] case class PythonFunction(
*/
private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
+/**
+ * Enumerate the type of command that will be sent to the Python worker
+ */
+private[spark] object PythonEvalType {
+ val NON_UDF = 0
+ val SQL_BATCHED_UDF = 1
+ val SQL_PANDAS_UDF = 2
+}
+
private[spark] object PythonRunner {
def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
new PythonRunner(
- Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
+ Seq(ChainedPythonFunctions(Seq(func))),
+ bufferSize,
+ reuse_worker,
+ PythonEvalType.NON_UDF,
+ Array(Array(0)))
}
}
@@ -100,7 +113,7 @@ private[spark] class PythonRunner(
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuse_worker: Boolean,
- isUDF: Boolean,
+ evalType: Int,
argOffsets: Array[Array[Int]])
extends Logging {
@@ -309,8 +322,8 @@ private[spark] class PythonRunner(
}
dataOut.flush()
// Serialized command:
- if (isUDF) {
- dataOut.writeInt(1)
+ dataOut.writeInt(evalType)
+ if (evalType != PythonEvalType.NON_UDF) {
dataOut.writeInt(funcs.length)
funcs.zip(argOffsets).foreach { case (chained, offsets) =>
dataOut.writeInt(offsets.length)
@@ -324,7 +337,6 @@ private[spark] class PythonRunner(
}
}
} else {
- dataOut.writeInt(0)
val command = funcs.head.funcs.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
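As a side note on the protocol change above: the eval type integer now always leads the command section, replacing the old 0/1 `isUDF` flag. A toy sketch of that framing, using big-endian ints as `java.io.DataOutputStream.writeInt` does (illustration only, not Spark's actual I/O code):
```
import io
import struct

NON_UDF, SQL_BATCHED_UDF, SQL_PANDAS_UDF = 0, 1, 2

# Writer side: the eval type is written first, unconditionally.
buf = io.BytesIO()
buf.write(struct.pack('>i', SQL_PANDAS_UDF))
buf.write(struct.pack('>i', 1))  # for UDF eval types, the number of chained-UDF groups follows

# Reader side (mirrored by read_int in worker.py below): branch on the leading int.
buf.seek(0)
eval_type = struct.unpack('>i', buf.read(4))[0]
assert eval_type == SQL_PANDAS_UDF
```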
http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 660b19a..887c702 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -81,6 +81,12 @@ class SpecialLengths(object):
NULL = -5
+class PythonEvalType(object):
+ NON_UDF = 0
+ SQL_BATCHED_UDF = 1
+ SQL_PANDAS_UDF = 2
+
+
class Serializer(object):
def dump_stream(self, iterator, stream):
@@ -187,8 +193,14 @@ class ArrowSerializer(FramedSerializer):
Serializes an Arrow stream.
"""
- def dumps(self, obj):
- raise NotImplementedError
+ def dumps(self, batch):
+ import pyarrow as pa
+ import io
+ sink = io.BytesIO()
+ writer = pa.RecordBatchFileWriter(sink, batch.schema)
+ writer.write_batch(batch)
+ writer.close()
+ return sink.getvalue()
def loads(self, obj):
import pyarrow as pa
@@ -199,6 +211,55 @@ class ArrowSerializer(FramedSerializer):
return "ArrowSerializer"
+class ArrowPandasSerializer(ArrowSerializer):
+ """
+ Serializes Pandas.Series as Arrow data.
+ """
+
+ def __init__(self):
+ super(ArrowPandasSerializer, self).__init__()
+
+ def dumps(self, series):
+ """
+ Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or
+ a list of series accompanied by an optional pyarrow type to coerce the data to.
+ """
+ import pyarrow as pa
+ # Make input conform to [(series1, type1), (series2, type2), ...]
+ if not isinstance(series, (list, tuple)) or \
+ (len(series) == 2 and isinstance(series[1], pa.DataType)):
+ series = [series]
+ series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+
+ # If a nullable integer series has been promoted to floating point with NaNs, need to cast
+ # NOTE: this is not necessary with Arrow >= 0.7
+ def cast_series(s, t):
+ if t is None or s.dtype == t.to_pandas_dtype():
+ return s
+ else:
+ return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)
+
+ arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series]
+ batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
+ return super(ArrowPandasSerializer, self).dumps(batch)
+
+ def loads(self, obj):
+ """
+ Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series
+ followed by a dictionary containing the total length of the loaded batches.
+ """
+ import pyarrow as pa
+ reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
+ batches = [reader.get_batch(i) for i in xrange(reader.num_record_batches)]
+ # NOTE: a 0-parameter pandas_udf will produce an empty batch that can have num_rows set
+ num_rows = sum((batch.num_rows for batch in batches))
+ table = pa.Table.from_batches(batches)
+ return [c.to_pandas() for c in table.itercolumns()] + [{"length": num_rows}]
+
+ def __repr__(self):
+ return "ArrowPandasSerializer"
+
+
class BatchedSerializer(Serializer):
"""
http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 57068fb..46e3a85 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2044,7 +2044,7 @@ class UserDefinedFunction(object):
.. versionadded:: 1.3
"""
- def __init__(self, func, returnType, name=None):
+ def __init__(self, func, returnType, name=None, vectorized=False):
if not callable(func):
raise TypeError(
"Not a function or callable (__call__ is not defined): "
@@ -2058,6 +2058,7 @@ class UserDefinedFunction(object):
self._name = name or (
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
+ self._vectorized = vectorized
@property
def returnType(self):
@@ -2089,7 +2090,7 @@ class UserDefinedFunction(object):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
- self._name, wrapped_func, jdt)
+ self._name, wrapped_func, jdt, self._vectorized)
return judf
def __call__(self, *cols):
@@ -2123,6 +2124,22 @@ class UserDefinedFunction(object):
return wrapper
+def _create_udf(f, returnType, vectorized):
+
+ def _udf(f, returnType=StringType(), vectorized=vectorized):
+ udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
+ return udf_obj._wrapped()
+
+ # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf
+ if f is None or isinstance(f, (str, DataType)):
+ # If DataType has been passed as a positional argument
+ # for decorator use it as a returnType
+ return_type = f or returnType
+ return functools.partial(_udf, returnType=return_type, vectorized=vectorized)
+ else:
+ return _udf(f=f, returnType=returnType, vectorized=vectorized)
+
+
@since(1.3)
def udf(f=None, returnType=StringType()):
"""Creates a :class:`Column` expression representing a user defined function (UDF).
@@ -2154,18 +2171,26 @@ def udf(f=None, returnType=StringType()):
| 8| JOHN DOE| 22|
+----------+--------------+------------+
"""
- def _udf(f, returnType=StringType()):
- udf_obj = UserDefinedFunction(f, returnType)
- return udf_obj._wrapped()
+ return _create_udf(f, returnType=returnType, vectorized=False)
- # decorator @udf, @udf() or @udf(dataType())
- if f is None or isinstance(f, (str, DataType)):
- # If DataType has been passed as a positional argument
- # for decorator use it as a returnType
- return_type = f or returnType
- return functools.partial(_udf, returnType=return_type)
+
+@since(2.3)
+def pandas_udf(f=None, returnType=StringType()):
+ """
+ Creates a :class:`Column` expression representing a user defined function (UDF) that accepts
+ `pandas.Series` as input arguments and outputs a `pandas.Series` of the same length.
+
+ :param f: python function if used as a standalone function
+ :param returnType: a :class:`pyspark.sql.types.DataType` object
+
+ # TODO: doctest
+ """
+ import inspect
+ # If function "f" does not define the optional kwargs, then wrap with a kwargs placeholder
+ if inspect.getargspec(f).keywords is None:
+ return _create_udf(lambda *a, **kwargs: f(*a), returnType=returnType, vectorized=True)
else:
- return _udf(f=f, returnType=returnType)
+ return _create_udf(f, returnType=returnType, vectorized=True)
blacklist = ['map', 'since', 'ignore_unicode_prefix']
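The `inspect.getargspec` check in `pandas_udf` above means a function without `**kwargs` is transparently wrapped so the worker can always pass kwargs through. A standalone sketch of just that wrapping (`wrap_if_no_kwargs` is a hypothetical name; `getargspec` matches the patch's Python-2-era API, modern code would use `inspect.getfullargspec`):
```
import inspect

def wrap_if_no_kwargs(f):
    # Mirrors pandas_udf above: add a kwargs placeholder when f lacks **kwargs
    if inspect.getargspec(f).keywords is None:
        return lambda *a, **kwargs: f(*a)
    return f

plain = wrap_if_no_kwargs(lambda a, b: a + b)
print(plain(1, 2, length=10))  # the "length" kwarg is swallowed; prints 3
```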
http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 6e7ddf9..ab76c48 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3136,6 +3136,203 @@ class ArrowTests(ReusedPySparkTestCase):
self.assertTrue(pdf.empty)
+@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+class VectorizedUDFTests(ReusedPySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.spark = SparkSession(cls.sc)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ cls.spark.stop()
+
+ def test_vectorized_udf_basic(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.range(10).select(
+ col('id').cast('string').alias('str'),
+ col('id').cast('int').alias('int'),
+ col('id').alias('long'),
+ col('id').cast('float').alias('float'),
+ col('id').cast('double').alias('double'),
+ col('id').cast('boolean').alias('bool'))
+ f = lambda x: x
+ str_f = pandas_udf(f, StringType())
+ int_f = pandas_udf(f, IntegerType())
+ long_f = pandas_udf(f, LongType())
+ float_f = pandas_udf(f, FloatType())
+ double_f = pandas_udf(f, DoubleType())
+ bool_f = pandas_udf(f, BooleanType())
+ res = df.select(str_f(col('str')), int_f(col('int')),
+ long_f(col('long')), float_f(col('float')),
+ double_f(col('double')), bool_f(col('bool')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_boolean(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(True,), (True,), (None,), (False,)]
+ schema = StructType().add("bool", BooleanType())
+ df = self.spark.createDataFrame(data, schema)
+ bool_f = pandas_udf(lambda x: x, BooleanType())
+ res = df.select(bool_f(col('bool')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_byte(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(None,), (2,), (3,), (4,)]
+ schema = StructType().add("byte", ByteType())
+ df = self.spark.createDataFrame(data, schema)
+ byte_f = pandas_udf(lambda x: x, ByteType())
+ res = df.select(byte_f(col('byte')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_short(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(None,), (2,), (3,), (4,)]
+ schema = StructType().add("short", ShortType())
+ df = self.spark.createDataFrame(data, schema)
+ short_f = pandas_udf(lambda x: x, ShortType())
+ res = df.select(short_f(col('short')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_int(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(None,), (2,), (3,), (4,)]
+ schema = StructType().add("int", IntegerType())
+ df = self.spark.createDataFrame(data, schema)
+ int_f = pandas_udf(lambda x: x, IntegerType())
+ res = df.select(int_f(col('int')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_long(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(None,), (2,), (3,), (4,)]
+ schema = StructType().add("long", LongType())
+ df = self.spark.createDataFrame(data, schema)
+ long_f = pandas_udf(lambda x: x, LongType())
+ res = df.select(long_f(col('long')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_float(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(3.0,), (5.0,), (-1.0,), (None,)]
+ schema = StructType().add("float", FloatType())
+ df = self.spark.createDataFrame(data, schema)
+ float_f = pandas_udf(lambda x: x, FloatType())
+ res = df.select(float_f(col('float')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_double(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [(3.0,), (5.0,), (-1.0,), (None,)]
+ schema = StructType().add("double", DoubleType())
+ df = self.spark.createDataFrame(data, schema)
+ double_f = pandas_udf(lambda x: x, DoubleType())
+ res = df.select(double_f(col('double')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_null_string(self):
+ from pyspark.sql.functions import pandas_udf, col
+ data = [("foo",), (None,), ("bar",), ("bar",)]
+ schema = StructType().add("str", StringType())
+ df = self.spark.createDataFrame(data, schema)
+ str_f = pandas_udf(lambda x: x, StringType())
+ res = df.select(str_f(col('str')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_zero_parameter(self):
+ from pyspark.sql.functions import pandas_udf
+ import pandas as pd
+ df = self.spark.range(10)
+ f0 = pandas_udf(lambda **kwargs: pd.Series(1).repeat(kwargs['length']), LongType())
+ res = df.select(f0())
+ self.assertEquals(df.select(lit(1)).collect(), res.collect())
+
+ def test_vectorized_udf_datatype_string(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.range(10).select(
+ col('id').cast('string').alias('str'),
+ col('id').cast('int').alias('int'),
+ col('id').alias('long'),
+ col('id').cast('float').alias('float'),
+ col('id').cast('double').alias('double'),
+ col('id').cast('boolean').alias('bool'))
+ f = lambda x: x
+ str_f = pandas_udf(f, 'string')
+ int_f = pandas_udf(f, 'integer')
+ long_f = pandas_udf(f, 'long')
+ float_f = pandas_udf(f, 'float')
+ double_f = pandas_udf(f, 'double')
+ bool_f = pandas_udf(f, 'boolean')
+ res = df.select(str_f(col('str')), int_f(col('int')),
+ long_f(col('long')), float_f(col('float')),
+ double_f(col('double')), bool_f(col('bool')))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_complex(self):
+ from pyspark.sql.functions import pandas_udf, col, expr
+ df = self.spark.range(10).select(
+ col('id').cast('int').alias('a'),
+ col('id').cast('int').alias('b'),
+ col('id').cast('double').alias('c'))
+ add = pandas_udf(lambda x, y: x + y, IntegerType())
+ power2 = pandas_udf(lambda x: 2 ** x, IntegerType())
+ mul = pandas_udf(lambda x, y: x * y, DoubleType())
+ res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c')))
+ expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c'))
+ self.assertEquals(expected.collect(), res.collect())
+
+ def test_vectorized_udf_exception(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.range(10)
+ raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType())
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'):
+ df.select(raise_exception(col('id'))).collect()
+
+ def test_vectorized_udf_invalid_length(self):
+ from pyspark.sql.functions import pandas_udf, col
+ import pandas as pd
+ df = self.spark.range(10)
+ raise_exception = pandas_udf(lambda: pd.Series(1), LongType())
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ Exception,
+ 'Result vector from pandas_udf was not the required length'):
+ df.select(raise_exception()).collect()
+
+ def test_vectorized_udf_mix_udf(self):
+ from pyspark.sql.functions import pandas_udf, udf, col
+ df = self.spark.range(10)
+ row_by_row_udf = udf(lambda x: x, LongType())
+ pd_udf = pandas_udf(lambda x: x, LongType())
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ Exception,
+ 'Can not mix vectorized and non-vectorized UDFs'):
+ df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()
+
+ def test_vectorized_udf_chained(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.range(10).toDF('x')
+ f = pandas_udf(lambda x: x + 1, LongType())
+ g = pandas_udf(lambda x: x - 1, LongType())
+ res = df.select(g(f(col('x'))))
+ self.assertEquals(df.collect(), res.collect())
+
+ def test_vectorized_udf_wrong_return_type(self):
+ from pyspark.sql.functions import pandas_udf, col
+ df = self.spark.range(10).toDF('x')
+ f = pandas_udf(lambda x: x * 1.0, StringType())
+ with QuietTest(self.sc):
+ with self.assertRaisesRegexp(
+ Exception,
+ 'Invalid.*type.*string'):
+ df.select(f(col('x'))).collect()
+
+
if __name__ == "__main__":
from pyspark.sql.tests import *
if xmlrunner:
http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index aaf520f..ebdc11c 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1597,6 +1597,33 @@ register_input_converter(DatetimeConverter())
register_input_converter(DateConverter())
+def toArrowType(dt):
+ """ Convert Spark data type to pyarrow type
+ """
+ import pyarrow as pa
+ if type(dt) == BooleanType:
+ arrow_type = pa.bool_()
+ elif type(dt) == ByteType:
+ arrow_type = pa.int8()
+ elif type(dt) == ShortType:
+ arrow_type = pa.int16()
+ elif type(dt) == IntegerType:
+ arrow_type = pa.int32()
+ elif type(dt) == LongType:
+ arrow_type = pa.int64()
+ elif type(dt) == FloatType:
+ arrow_type = pa.float32()
+ elif type(dt) == DoubleType:
+ arrow_type = pa.float64()
+ elif type(dt) == DecimalType:
+ arrow_type = pa.decimal(dt.precision, dt.scale)
+ elif type(dt) == StringType:
+ arrow_type = pa.string()
+ else:
+ raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
+ return arrow_type
+
+
def _test():
import doctest
from pyspark.context import SparkContext
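A quick sanity check of the mapping above (assumes pyspark with this patch plus pyarrow on the path; pyarrow `DataType` objects compare with `==`):
```
import pyarrow as pa
from pyspark.sql.types import toArrowType, LongType, DoubleType, StringType

assert toArrowType(LongType()) == pa.int64()
assert toArrowType(DoubleType()) == pa.float64()
assert toArrowType(StringType()) == pa.string()
```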
http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index baaa3fe..0e35cf7 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,7 +30,9 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.taskcontext import TaskContext
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
- write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
+ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \
+ BatchedSerializer, ArrowPandasSerializer
+from pyspark.sql.types import toArrowType
from pyspark import shuffle
pickleSer = PickleSerializer()
@@ -58,9 +60,12 @@ def read_command(serializer, file):
return command
-def chain(f, g):
- """chain two function together """
- return lambda *a: g(f(*a))
+def chain(f, g, eval_type):
+ """chain two functions together """
+ if eval_type == PythonEvalType.SQL_PANDAS_UDF:
+ return lambda *a, **kwargs: g(f(*a, **kwargs), **kwargs)
+ else:
+ return lambda *a: g(f(*a))
def wrap_udf(f, return_type):
@@ -71,7 +76,21 @@ def wrap_udf(f, return_type):
return lambda *a: f(*a)
-def read_single_udf(pickleSer, infile):
+def wrap_pandas_udf(f, return_type):
+ arrow_return_type = toArrowType(return_type)
+
+ def verify_result_length(*a):
+ kwargs = a[-1]
+ result = f(*a[:-1], **kwargs)
+ if len(result) != kwargs["length"]:
+ raise RuntimeError("Result vector from pandas_udf was not the required length: "
+ "expected %d, got %d\nUse input vector length or kwargs['length']"
+ % (kwargs["length"], len(result)))
+ return result, arrow_return_type
+ return lambda *a: verify_result_length(*a)
+
+
+def read_single_udf(pickleSer, infile, eval_type):
num_arg = read_int(infile)
arg_offsets = [read_int(infile) for i in range(num_arg)]
row_func = None
@@ -80,17 +99,22 @@ def read_single_udf(pickleSer, infile):
if row_func is None:
row_func = f
else:
- row_func = chain(row_func, f)
+ row_func = chain(row_func, f, eval_type)
# the last returnType will be the return type of UDF
- return arg_offsets, wrap_udf(row_func, return_type)
+ if eval_type == PythonEvalType.SQL_PANDAS_UDF:
+ # A pandas_udf will take kwargs as the last argument
+ arg_offsets = arg_offsets + [-1]
+ return arg_offsets, wrap_pandas_udf(row_func, return_type)
+ else:
+ return arg_offsets, wrap_udf(row_func, return_type)
-def read_udfs(pickleSer, infile):
+def read_udfs(pickleSer, infile, eval_type):
num_udfs = read_int(infile)
udfs = {}
call_udf = []
for i in range(num_udfs):
- arg_offsets, udf = read_single_udf(pickleSer, infile)
+ arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
udfs['f%d' % i] = udf
args = ["a[%d]" % o for o in arg_offsets]
call_udf.append("f%d(%s)" % (i, ", ".join(args)))
@@ -102,7 +126,12 @@ def read_udfs(pickleSer, infile):
mapper = eval(mapper_str, udfs)
func = lambda _, it: map(mapper, it)
- ser = BatchedSerializer(PickleSerializer(), 100)
+
+ if eval_type == PythonEvalType.SQL_PANDAS_UDF:
+ ser = ArrowPandasSerializer()
+ else:
+ ser = BatchedSerializer(PickleSerializer(), 100)
+
# profiling is not supported for UDF
return func, None, ser, ser
@@ -159,11 +188,11 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)
_accumulatorRegistry.clear()
- is_sql_udf = read_int(infile)
- if is_sql_udf:
- func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
- else:
+ eval_type = read_int(infile)
+ if eval_type == PythonEvalType.NON_UDF:
func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+ else:
+ func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
init_time = time.time()
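The pandas branch of `chain` above threads the kwargs through every link, so each chained function sees the same "length". A standalone sketch of that composition with toy functions (not Spark's actual worker code):
```
import pandas as pd

def chain_pandas(f, g):
    # Mirrors chain() above for SQL_PANDAS_UDF: kwargs flow to every link
    return lambda *a, **kwargs: g(f(*a, **kwargs), **kwargs)

f = lambda s, **kwargs: s + 1
g = lambda s, **kwargs: s * 2
h = chain_pandas(f, g)
print(h(pd.Series([1, 2, 3]), length=3))  # elementwise (x + 1) * 2 -> 4, 6, 8
```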
http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
new file mode 100644
index 0000000..f8bdc1e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A physical plan that evaluates a [[PythonUDF]],
+ */
+case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
+ extends EvalPythonExec(udfs, output, child) {
+
+ protected override def evaluate(
+ funcs: Seq[ChainedPythonFunctions],
+ bufferSize: Int,
+ reuseWorker: Boolean,
+ argOffsets: Array[Array[Int]],
+ iter: Iterator[InternalRow],
+ schema: StructType,
+ context: TaskContext): Iterator[InternalRow] = {
+ val inputIterator = ArrowConverters.toPayloadIterator(
+ iter, schema, conf.arrowMaxRecordsPerBatch, context).map(_.asPythonSerializable)
+
+ // Output iterator for results from Python.
+ val outputIterator = new PythonRunner(
+ funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets)
+ .compute(inputIterator, context.partitionId(), context)
+
+ val outputRowIterator = ArrowConverters.fromPayloadIterator(
+ outputIterator.map(new ArrowPayload(_)), context)
+
+ // Verify that the output schema is correct
+ val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
+ .map { case (attr, i) => attr.withName(s"_$i") })
+ assert(schemaOut.equals(outputRowIterator.schema),
+ s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")
+
+ outputRowIterator
+ }
+}
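From the Python side, a pandas_udf query should now plan through this node. A hedged sanity check (assumes an active `spark` session; the explain() output named in the comment is illustrative):
```
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.types import LongType

plus_one = pandas_udf(lambda x: x + 1, LongType())
df = spark.range(10).select(plus_one(col('id')))
df.explain()  # expect ArrowEvalPythonExec, not BatchEvalPythonExec, in the physical plan
```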
http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 3e176e2..2978eac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -17,153 +17,78 @@
package org.apache.spark.sql.execution.python
-import java.io.File
-
import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
import net.razorvine.pickle.{Pickler, Unpickler}
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
-import org.apache.spark.util.Utils
-
+import org.apache.spark.sql.types.{StructField, StructType}
/**
- * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time.
- *
- * Python evaluation works by sending the necessary (projected) input data via a socket to an
- * external Python process, and combine the result from the Python process with the original row.
- *
- * For each row we send to Python, we also put it in a queue first. For each output row from Python,
- * we drain the queue to find the original input row. Note that if the Python process is way too
- * slow, this could lead to the queue growing unbounded and spill into disk when run out of memory.
- *
- * Here is a diagram to show how this works:
- *
- * Downstream (for parent)
- * / \
- * / socket (output of UDF)
- * / \
- * RowQueue Python
- * \ /
- * \ socket (input of UDF)
- * \ /
- * upstream (from child)
- *
- * The rows sent to and received from Python are packed into batches (100 rows) and serialized,
- * there should be always some rows buffered in the socket or Python process, so the pulling from
- * RowQueue ALWAYS happened after pushing into it.
+ * A physical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
- extends SparkPlan {
-
- def children: Seq[SparkPlan] = child :: Nil
-
- override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
-
- private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
- udf.children match {
- case Seq(u: PythonUDF) =>
- val (chained, children) = collectFunctions(u)
- (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
- case children =>
- // There should not be any other UDFs, or the children can't be evaluated directly.
- assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
- (ChainedPythonFunctions(Seq(udf.func)), udf.children)
- }
- }
-
- protected override def doExecute(): RDD[InternalRow] = {
- val inputRDD = child.execute().map(_.copy())
- val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
- val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
-
- inputRDD.mapPartitions { iter =>
- EvaluatePython.registerPicklers() // register pickler for Row
-
- // The queue used to buffer input rows so we can drain it to
- // combine input with output from Python.
- val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(),
- new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
- TaskContext.get().addTaskCompletionListener({ ctx =>
- queue.close()
- })
-
- val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
-
- // flatten all the arguments
- val allInputs = new ArrayBuffer[Expression]
- val dataTypes = new ArrayBuffer[DataType]
- val argOffsets = inputs.map { input =>
- input.map { e =>
- if (allInputs.exists(_.semanticEquals(e))) {
- allInputs.indexWhere(_.semanticEquals(e))
- } else {
- allInputs += e
- dataTypes += e.dataType
- allInputs.length - 1
- }
- }.toArray
- }.toArray
- val projection = newMutableProjection(allInputs, child.output)
- val schema = StructType(dataTypes.map(dt => StructField("", dt)))
- val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
-
- // enable memo iff we serialize the row with schema (schema and class should be memorized)
- val pickle = new Pickler(needConversion)
- // Input iterator to Python: input rows are grouped so we send them in batches to Python.
- // For each row, add it to the queue.
- val inputIterator = iter.map { inputRow =>
- queue.add(inputRow.asInstanceOf[UnsafeRow])
- val row = projection(inputRow)
- if (needConversion) {
- EvaluatePython.toJava(row, schema)
- } else {
- // fast path for these types that does not need conversion in Python
- val fields = new Array[Any](row.numFields)
- var i = 0
- while (i < row.numFields) {
- val dt = dataTypes(i)
- fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
- i += 1
- }
- fields
- }
- }.grouped(100).map(x => pickle.dumps(x.toArray))
-
- val context = TaskContext.get()
-
- // Output iterator for results from Python.
- val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
- .compute(inputIterator, context.partitionId(), context)
-
- val unpickle = new Unpickler
- val mutableRow = new GenericInternalRow(1)
- val joined = new JoinedRow
- val resultType = if (udfs.length == 1) {
- udfs.head.dataType
+ extends EvalPythonExec(udfs, output, child) {
+
+ protected override def evaluate(
+ funcs: Seq[ChainedPythonFunctions],
+ bufferSize: Int,
+ reuseWorker: Boolean,
+ argOffsets: Array[Array[Int]],
+ iter: Iterator[InternalRow],
+ schema: StructType,
+ context: TaskContext): Iterator[InternalRow] = {
+ EvaluatePython.registerPicklers() // register pickler for Row
+
+ val dataTypes = schema.map(_.dataType)
+ val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
+
+ // enable memo iff we serialize the row with schema (schema and class should be memorized)
+ val pickle = new Pickler(needConversion)
+ // Input iterator to Python: input rows are grouped so we send them in batches to Python.
+ // For each row, add it to the queue.
+ val inputIterator = iter.map { row =>
+ if (needConversion) {
+ EvaluatePython.toJava(row, schema)
} else {
- StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
- }
- val resultProj = UnsafeProjection.create(output, output)
- outputIterator.flatMap { pickedResult =>
- val unpickledBatch = unpickle.loads(pickedResult)
- unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
- }.map { result =>
- val row = if (udfs.length == 1) {
- // fast path for single UDF
- mutableRow(0) = EvaluatePython.fromJava(result, resultType)
- mutableRow
- } else {
- EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+ // fast path for these types that does not need conversion in Python
+ val fields = new Array[Any](row.numFields)
+ var i = 0
+ while (i < row.numFields) {
+ val dt = dataTypes(i)
+ fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+ i += 1
}
- resultProj(joined(queue.remove(), row))
+ fields
+ }
+ }.grouped(100).map(x => pickle.dumps(x.toArray))
+
+ // Output iterator for results from Python.
+ val outputIterator = new PythonRunner(
+ funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
+ .compute(inputIterator, context.partitionId(), context)
+
+ val unpickle = new Unpickler
+ val mutableRow = new GenericInternalRow(1)
+ val resultType = if (udfs.length == 1) {
+ udfs.head.dataType
+ } else {
+ StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
+ }
+ outputIterator.flatMap { pickedResult =>
+ val unpickledBatch = unpickle.loads(pickedResult)
+ unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
+ }.map { result =>
+ if (udfs.length == 1) {
+ // fast path for single UDF
+ mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+ mutableRow
+ } else {
+ EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
}
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
new file mode 100644
index 0000000..860dc78
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.File
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+
+/**
+ * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time.
+ *
+ * Python evaluation works by sending the necessary (projected) input data via a socket to an
+ * external Python process, and combining the result from the Python process with the original row.
+ *
+ * For each row we send to Python, we also put it in a queue first. For each output row from Python,
+ * we drain the queue to find the original input row. Note that if the Python process is way too
+ * slow, this could lead to the queue growing unbounded and spilling to disk when memory runs out.
+ *
+ * Here is a diagram to show how this works:
+ *
+ * Downstream (for parent)
+ * / \
+ * / socket (output of UDF)
+ * / \
+ * RowQueue Python
+ * \ /
+ * \ socket (input of UDF)
+ * \ /
+ * upstream (from child)
+ *
+ * The rows sent to and received from Python are packed into batches (100 rows) and serialized,
+ * there should always be some rows buffered in the socket or Python process, so pulling from the
+ * RowQueue ALWAYS happens after pushing into it.
+ */
+abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
+ extends SparkPlan {
+
+ def children: Seq[SparkPlan] = child :: Nil
+
+ override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+
+ private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+ udf.children match {
+ case Seq(u: PythonUDF) =>
+ val (chained, children) = collectFunctions(u)
+ (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+ case children =>
+ // There should not be any other UDFs, or the children can't be evaluated directly.
+ assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
+ (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+ }
+ }
+
+ protected def evaluate(
+ funcs: Seq[ChainedPythonFunctions],
+ bufferSize: Int,
+ reuseWorker: Boolean,
+ argOffsets: Array[Array[Int]],
+ iter: Iterator[InternalRow],
+ schema: StructType,
+ context: TaskContext): Iterator[InternalRow]
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val inputRDD = child.execute().map(_.copy())
+ val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
+ val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
+
+ inputRDD.mapPartitions { iter =>
+ val context = TaskContext.get()
+
+ // The queue used to buffer input rows so we can drain it to
+ // combine input with output from Python.
+ val queue = HybridRowQueue(context.taskMemoryManager(),
+ new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
+ context.addTaskCompletionListener { ctx =>
+ queue.close()
+ }
+
+ val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
+
+ // flatten all the arguments
+ val allInputs = new ArrayBuffer[Expression]
+ val dataTypes = new ArrayBuffer[DataType]
+ val argOffsets = inputs.map { input =>
+ input.map { e =>
+ if (allInputs.exists(_.semanticEquals(e))) {
+ allInputs.indexWhere(_.semanticEquals(e))
+ } else {
+ allInputs += e
+ dataTypes += e.dataType
+ allInputs.length - 1
+ }
+ }.toArray
+ }.toArray
+ val projection = newMutableProjection(allInputs, child.output)
+ val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
+ StructField(s"_$i", dt)
+ })
+
+ // Add rows to queue to join later with the result.
+ val projectedRowIter = iter.map { inputRow =>
+ queue.add(inputRow.asInstanceOf[UnsafeRow])
+ projection(inputRow)
+ }
+
+ val outputRowIterator = evaluate(
+ pyFuncs, bufferSize, reuseWorker, argOffsets, projectedRowIter, schema, context)
+
+ val joined = new JoinedRow
+ val resultProj = UnsafeProjection.create(output, output)
+
+ outputRowIterator.map { outputRow =>
+ resultProj(joined(queue.remove(), outputRow))
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 882a5ce..fec456d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -138,7 +138,16 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
AttributeReference(s"pythonUDF$i", u.dataType)()
}
- val evaluation = BatchEvalPythonExec(validUdfs, child.output ++ resultAttrs, child)
+
+ val evaluation = validUdfs.partition(_.vectorized) match {
+ case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
+ ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
+ case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
+ BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
+ case _ =>
+ throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs")
+ }
+
attributeMap ++= validUdfs.zip(resultAttrs)
evaluation
} else {
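This partition is what surfaces the planning-time error exercised by `test_vectorized_udf_mix_udf`. Seen from PySpark (a sketch assuming an active `spark` session):
```
from pyspark.sql.functions import udf, pandas_udf, col
from pyspark.sql.types import LongType

plain = udf(lambda x: x, LongType())
vec = pandas_udf(lambda x: x, LongType())
df = spark.range(10)
# Mixing the two in one projection fails during planning:
# df.select(plain(col('id')), vec(col('id'))).collect()
# -> IllegalArgumentException: Can not mix vectorized and non-vectorized UDFs
```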
http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 7ebbdb9..84a6d9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -28,7 +28,8 @@ case class PythonUDF(
name: String,
func: PythonFunction,
dataType: DataType,
- children: Seq[Expression])
+ children: Seq[Expression],
+ vectorized: Boolean)
extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {
override def toString: String = s"$name(${children.mkString(", ")})"
http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 0d39c8f..a30a80a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -28,10 +28,11 @@ import org.apache.spark.sql.types.DataType
case class UserDefinedPythonFunction(
name: String,
func: PythonFunction,
- dataType: DataType) {
+ dataType: DataType,
+ vectorized: Boolean) {
def builder(e: Seq[Expression]): PythonUDF = {
- PythonUDF(name, func, dataType, e)
+ PythonUDF(name, func, dataType, e, vectorized)
}
/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index bbd9484..153e6e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -105,5 +105,8 @@ class DummyUDF extends PythonFunction(
broadcastVars = null,
accumulator = null)
-class MyDummyPythonUDF
- extends UserDefinedPythonFunction(name = "dummyUDF", func = new DummyUDF, dataType = BooleanType)
+class MyDummyPythonUDF extends UserDefinedPythonFunction(
+ name = "dummyUDF",
+ func = new DummyUDF,
+ dataType = BooleanType,
+ vectorized = false)