You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/09/22 08:20:28 UTC

spark git commit: [SPARK-21190][PYSPARK] Python Vectorized UDFs

Repository: spark
Updated Branches:
  refs/heads/master 8f130ad40 -> 27fc536d9


[SPARK-21190][PYSPARK] Python Vectorized UDFs

This PR adds vectorized UDFs to the Python API.

**Proposed API**
Introduce a flag to turn on vectorization for a defined UDF, for example:

```
@pandas_udf(DoubleType())
def plus(a, b):
    return a + b
```
or

```
plus = pandas_udf(lambda a, b: a + b, DoubleType())
```
Usage is the same as for normal UDFs.

**0-parameter UDFs**
pandas_udf functions can declare an optional `**kwargs` which, when the function is evaluated, will contain a key "length" giving the required length of the output. For example:

```
@pandas_udf(LongType())
def f0(**kwargs):
    return pd.Series(1).repeat(kwargs["length"])

df.select(f0())
```

Added new unit tests in pyspark.sql that are enabled if pyarrow and Pandas are available.

- [x] Fix support for promoted types with null values
- [ ] Discuss 0-param UDF API (use of kwargs)
- [x] Add tests for chained UDFs
- [ ] Discuss behavior when pyarrow not installed / enabled
- [ ] Cleanup pydoc and add user docs

Author: Bryan Cutler <cu...@gmail.com>
Author: Takuya UESHIN <ue...@databricks.com>

Closes #18659 from BryanCutler/arrow-vectorized-udfs-SPARK-21404.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/27fc536d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/27fc536d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/27fc536d

Branch: refs/heads/master
Commit: 27fc536d9a54eccef7d1cbbe2a6a008043d62ba4
Parents: 8f130ad
Author: Bryan Cutler <cu...@gmail.com>
Authored: Fri Sep 22 16:17:41 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Fri Sep 22 16:17:50 2017 +0800

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala |  22 ++-
 python/pyspark/serializers.py                   |  65 +++++-
 python/pyspark/sql/functions.py                 |  49 +++--
 python/pyspark/sql/tests.py                     | 197 +++++++++++++++++++
 python/pyspark/sql/types.py                     |  27 +++
 python/pyspark/worker.py                        |  57 ++++--
 .../execution/python/ArrowEvalPythonExec.scala  |  61 ++++++
 .../execution/python/BatchEvalPythonExec.scala  | 193 ++++++------------
 .../sql/execution/python/EvalPythonExec.scala   | 142 +++++++++++++
 .../execution/python/ExtractPythonUDFs.scala    |  11 +-
 .../spark/sql/execution/python/PythonUDF.scala  |   3 +-
 .../python/UserDefinedPythonFunction.scala      |   5 +-
 .../python/BatchEvalPythonExecSuite.scala       |   7 +-
 13 files changed, 666 insertions(+), 173 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 3377101..86d0405 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -83,10 +83,23 @@ private[spark] case class PythonFunction(
  */
 private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
 
+/**
+ * Enumerate the type of command that will be sent to the Python worker
+ */
+private[spark] object PythonEvalType {
+  val NON_UDF = 0
+  val SQL_BATCHED_UDF = 1
+  val SQL_PANDAS_UDF = 2
+}
+
 private[spark] object PythonRunner {
   def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
     new PythonRunner(
-      Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
+      Seq(ChainedPythonFunctions(Seq(func))),
+      bufferSize,
+      reuse_worker,
+      PythonEvalType.NON_UDF,
+      Array(Array(0)))
   }
 }
 
@@ -100,7 +113,7 @@ private[spark] class PythonRunner(
     funcs: Seq[ChainedPythonFunctions],
     bufferSize: Int,
     reuse_worker: Boolean,
-    isUDF: Boolean,
+    evalType: Int,
     argOffsets: Array[Array[Int]])
   extends Logging {
 
@@ -309,8 +322,8 @@ private[spark] class PythonRunner(
         }
         dataOut.flush()
         // Serialized command:
-        if (isUDF) {
-          dataOut.writeInt(1)
+        dataOut.writeInt(evalType)
+        if (evalType != PythonEvalType.NON_UDF) {
           dataOut.writeInt(funcs.length)
           funcs.zip(argOffsets).foreach { case (chained, offsets) =>
             dataOut.writeInt(offsets.length)
@@ -324,7 +337,6 @@ private[spark] class PythonRunner(
             }
           }
         } else {
-          dataOut.writeInt(0)
           val command = funcs.head.funcs.head.command
           dataOut.writeInt(command.length)
           dataOut.write(command)

http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 660b19a..887c702 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -81,6 +81,12 @@ class SpecialLengths(object):
     NULL = -5
 
 
+class PythonEvalType(object):
+    NON_UDF = 0
+    SQL_BATCHED_UDF = 1
+    SQL_PANDAS_UDF = 2
+
+
 class Serializer(object):
 
     def dump_stream(self, iterator, stream):
@@ -187,8 +193,14 @@ class ArrowSerializer(FramedSerializer):
     Serializes an Arrow stream.
     """
 
-    def dumps(self, obj):
-        raise NotImplementedError
+    def dumps(self, batch):
+        import pyarrow as pa
+        import io
+        sink = io.BytesIO()
+        writer = pa.RecordBatchFileWriter(sink, batch.schema)
+        writer.write_batch(batch)
+        writer.close()
+        return sink.getvalue()
 
     def loads(self, obj):
         import pyarrow as pa
@@ -199,6 +211,55 @@ class ArrowSerializer(FramedSerializer):
         return "ArrowSerializer"
 
 
+class ArrowPandasSerializer(ArrowSerializer):
+    """
+    Serializes Pandas.Series as Arrow data.
+    """
+
+    def __init__(self):
+        super(ArrowPandasSerializer, self).__init__()
+
+    def dumps(self, series):
+        """
+        Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or
+        a list of series accompanied by an optional pyarrow type to coerce the data to.
+        """
+        import pyarrow as pa
+        # Make input conform to [(series1, type1), (series2, type2), ...]
+        if not isinstance(series, (list, tuple)) or \
+                (len(series) == 2 and isinstance(series[1], pa.DataType)):
+            series = [series]
+        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+
+        # If a nullable integer series has been promoted to floating point with NaNs, need to cast
+        # NOTE: this is not necessary with Arrow >= 0.7
+        def cast_series(s, t):
+            if t is None or s.dtype == t.to_pandas_dtype():
+                return s
+            else:
+                return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)
+
+        arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series]
+        batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
+        return super(ArrowPandasSerializer, self).dumps(batch)
+
+    def loads(self, obj):
+        """
+        Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series
+        followed by a dictionary containing length of the loaded batches.
+        """
+        import pyarrow as pa
+        reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
+        batches = [reader.get_batch(i) for i in xrange(reader.num_record_batches)]
+        # NOTE: a 0-parameter pandas_udf will produce an empty batch that can have num_rows set
+        num_rows = sum((batch.num_rows for batch in batches))
+        table = pa.Table.from_batches(batches)
+        return [c.to_pandas() for c in table.itercolumns()] + [{"length": num_rows}]
+
+    def __repr__(self):
+        return "ArrowPandasSerializer"
+
+
 class BatchedSerializer(Serializer):
 
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 57068fb..46e3a85 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2044,7 +2044,7 @@ class UserDefinedFunction(object):
 
     .. versionadded:: 1.3
     """
-    def __init__(self, func, returnType, name=None):
+    def __init__(self, func, returnType, name=None, vectorized=False):
         if not callable(func):
             raise TypeError(
                 "Not a function or callable (__call__ is not defined): "
@@ -2058,6 +2058,7 @@ class UserDefinedFunction(object):
         self._name = name or (
             func.__name__ if hasattr(func, '__name__')
             else func.__class__.__name__)
+        self._vectorized = vectorized
 
     @property
     def returnType(self):
@@ -2089,7 +2090,7 @@ class UserDefinedFunction(object):
         wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            self._name, wrapped_func, jdt)
+            self._name, wrapped_func, jdt, self._vectorized)
         return judf
 
     def __call__(self, *cols):
@@ -2123,6 +2124,22 @@ class UserDefinedFunction(object):
         return wrapper
 
 
+def _create_udf(f, returnType, vectorized):
+
+    def _udf(f, returnType=StringType(), vectorized=vectorized):
+        udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
+        return udf_obj._wrapped()
+
+    # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf
+    if f is None or isinstance(f, (str, DataType)):
+        # If DataType has been passed as a positional argument
+        # for decorator use it as a returnType
+        return_type = f or returnType
+        return functools.partial(_udf, returnType=return_type, vectorized=vectorized)
+    else:
+        return _udf(f=f, returnType=returnType, vectorized=vectorized)
+
+
 @since(1.3)
 def udf(f=None, returnType=StringType()):
     """Creates a :class:`Column` expression representing a user defined function (UDF).
@@ -2154,18 +2171,26 @@ def udf(f=None, returnType=StringType()):
     |         8|      JOHN DOE|          22|
     +----------+--------------+------------+
     """
-    def _udf(f, returnType=StringType()):
-        udf_obj = UserDefinedFunction(f, returnType)
-        return udf_obj._wrapped()
+    return _create_udf(f, returnType=returnType, vectorized=False)
 
-    # decorator @udf, @udf() or @udf(dataType())
-    if f is None or isinstance(f, (str, DataType)):
-        # If DataType has been passed as a positional argument
-        # for decorator use it as a returnType
-        return_type = f or returnType
-        return functools.partial(_udf, returnType=return_type)
+
+@since(2.3)
+def pandas_udf(f=None, returnType=StringType()):
+    """
+    Creates a :class:`Column` expression representing a user defined function (UDF) that accepts
+    `Pandas.Series` as input arguments and outputs a `Pandas.Series` of the same length.
+
+    :param f: python function if used as a standalone function
+    :param returnType: a :class:`pyspark.sql.types.DataType` object
+
+    # TODO: doctest
+    """
+    import inspect
+    # If function "f" does not define the optional kwargs, then wrap with a kwargs placeholder
+    if inspect.getargspec(f).keywords is None:
+        return _create_udf(lambda *a, **kwargs: f(*a), returnType=returnType, vectorized=True)
     else:
-        return _udf(f=f, returnType=returnType)
+        return _create_udf(f, returnType=returnType, vectorized=True)
 
 
 blacklist = ['map', 'since', 'ignore_unicode_prefix']

http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 6e7ddf9..ab76c48 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3136,6 +3136,203 @@ class ArrowTests(ReusedPySparkTestCase):
         self.assertTrue(pdf.empty)
 
 
+@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+class VectorizedUDFTests(ReusedPySparkTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        cls.spark.stop()
+
+    def test_vectorized_udf_basic(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.range(10).select(
+            col('id').cast('string').alias('str'),
+            col('id').cast('int').alias('int'),
+            col('id').alias('long'),
+            col('id').cast('float').alias('float'),
+            col('id').cast('double').alias('double'),
+            col('id').cast('boolean').alias('bool'))
+        f = lambda x: x
+        str_f = pandas_udf(f, StringType())
+        int_f = pandas_udf(f, IntegerType())
+        long_f = pandas_udf(f, LongType())
+        float_f = pandas_udf(f, FloatType())
+        double_f = pandas_udf(f, DoubleType())
+        bool_f = pandas_udf(f, BooleanType())
+        res = df.select(str_f(col('str')), int_f(col('int')),
+                        long_f(col('long')), float_f(col('float')),
+                        double_f(col('double')), bool_f(col('bool')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_boolean(self):
+        from pyspark.sql.functions import pandas_udf, col
+        data = [(True,), (True,), (None,), (False,)]
+        schema = StructType().add("bool", BooleanType())
+        df = self.spark.createDataFrame(data, schema)
+        bool_f = pandas_udf(lambda x: x, BooleanType())
+        res = df.select(bool_f(col('bool')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_byte(self):
+        from pyspark.sql.functions import pandas_udf, col
+        data = [(None,), (2,), (3,), (4,)]
+        schema = StructType().add("byte", ByteType())
+        df = self.spark.createDataFrame(data, schema)
+        byte_f = pandas_udf(lambda x: x, ByteType())
+        res = df.select(byte_f(col('byte')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_short(self):
+        from pyspark.sql.functions import pandas_udf, col
+        data = [(None,), (2,), (3,), (4,)]
+        schema = StructType().add("short", ShortType())
+        df = self.spark.createDataFrame(data, schema)
+        short_f = pandas_udf(lambda x: x, ShortType())
+        res = df.select(short_f(col('short')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_int(self):
+        from pyspark.sql.functions import pandas_udf, col
+        data = [(None,), (2,), (3,), (4,)]
+        schema = StructType().add("int", IntegerType())
+        df = self.spark.createDataFrame(data, schema)
+        int_f = pandas_udf(lambda x: x, IntegerType())
+        res = df.select(int_f(col('int')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_long(self):
+        from pyspark.sql.functions import pandas_udf, col
+        data = [(None,), (2,), (3,), (4,)]
+        schema = StructType().add("long", LongType())
+        df = self.spark.createDataFrame(data, schema)
+        long_f = pandas_udf(lambda x: x, LongType())
+        res = df.select(long_f(col('long')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_float(self):
+        from pyspark.sql.functions import pandas_udf, col
+        data = [(3.0,), (5.0,), (-1.0,), (None,)]
+        schema = StructType().add("float", FloatType())
+        df = self.spark.createDataFrame(data, schema)
+        float_f = pandas_udf(lambda x: x, FloatType())
+        res = df.select(float_f(col('float')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_double(self):
+        from pyspark.sql.functions import pandas_udf, col
+        data = [(3.0,), (5.0,), (-1.0,), (None,)]
+        schema = StructType().add("double", DoubleType())
+        df = self.spark.createDataFrame(data, schema)
+        double_f = pandas_udf(lambda x: x, DoubleType())
+        res = df.select(double_f(col('double')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_null_string(self):
+        from pyspark.sql.functions import pandas_udf, col
+        data = [("foo",), (None,), ("bar",), ("bar",)]
+        schema = StructType().add("str", StringType())
+        df = self.spark.createDataFrame(data, schema)
+        str_f = pandas_udf(lambda x: x, StringType())
+        res = df.select(str_f(col('str')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_zero_parameter(self):
+        from pyspark.sql.functions import pandas_udf
+        import pandas as pd
+        df = self.spark.range(10)
+        f0 = pandas_udf(lambda **kwargs: pd.Series(1).repeat(kwargs['length']), LongType())
+        res = df.select(f0())
+        self.assertEquals(df.select(lit(1)).collect(), res.collect())
+
+    def test_vectorized_udf_datatype_string(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.range(10).select(
+            col('id').cast('string').alias('str'),
+            col('id').cast('int').alias('int'),
+            col('id').alias('long'),
+            col('id').cast('float').alias('float'),
+            col('id').cast('double').alias('double'),
+            col('id').cast('boolean').alias('bool'))
+        f = lambda x: x
+        str_f = pandas_udf(f, 'string')
+        int_f = pandas_udf(f, 'integer')
+        long_f = pandas_udf(f, 'long')
+        float_f = pandas_udf(f, 'float')
+        double_f = pandas_udf(f, 'double')
+        bool_f = pandas_udf(f, 'boolean')
+        res = df.select(str_f(col('str')), int_f(col('int')),
+                        long_f(col('long')), float_f(col('float')),
+                        double_f(col('double')), bool_f(col('bool')))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_complex(self):
+        from pyspark.sql.functions import pandas_udf, col, expr
+        df = self.spark.range(10).select(
+            col('id').cast('int').alias('a'),
+            col('id').cast('int').alias('b'),
+            col('id').cast('double').alias('c'))
+        add = pandas_udf(lambda x, y: x + y, IntegerType())
+        power2 = pandas_udf(lambda x: 2 ** x, IntegerType())
+        mul = pandas_udf(lambda x, y: x * y, DoubleType())
+        res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c')))
+        expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c'))
+        self.assertEquals(expected.collect(), res.collect())
+
+    def test_vectorized_udf_exception(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.range(10)
+        raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'):
+                df.select(raise_exception(col('id'))).collect()
+
+    def test_vectorized_udf_invalid_length(self):
+        from pyspark.sql.functions import pandas_udf, col
+        import pandas as pd
+        df = self.spark.range(10)
+        raise_exception = pandas_udf(lambda: pd.Series(1), LongType())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    Exception,
+                    'Result vector from pandas_udf was not the required length'):
+                df.select(raise_exception()).collect()
+
+    def test_vectorized_udf_mix_udf(self):
+        from pyspark.sql.functions import pandas_udf, udf, col
+        df = self.spark.range(10)
+        row_by_row_udf = udf(lambda x: x, LongType())
+        pd_udf = pandas_udf(lambda x: x, LongType())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    Exception,
+                    'Can not mix vectorized and non-vectorized UDFs'):
+                df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()
+
+    def test_vectorized_udf_chained(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.range(10).toDF('x')
+        f = pandas_udf(lambda x: x + 1, LongType())
+        g = pandas_udf(lambda x: x - 1, LongType())
+        res = df.select(g(f(col('x'))))
+        self.assertEquals(df.collect(), res.collect())
+
+    def test_vectorized_udf_wrong_return_type(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.range(10).toDF('x')
+        f = pandas_udf(lambda x: x * 1.0, StringType())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    Exception,
+                    'Invalid.*type.*string'):
+                df.select(f(col('x'))).collect()
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests import *
     if xmlrunner:

http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index aaf520f..ebdc11c 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1597,6 +1597,33 @@ register_input_converter(DatetimeConverter())
 register_input_converter(DateConverter())
 
 
+def toArrowType(dt):
+    """ Convert Spark data type to pyarrow type
+    """
+    import pyarrow as pa
+    if type(dt) == BooleanType:
+        arrow_type = pa.bool_()
+    elif type(dt) == ByteType:
+        arrow_type = pa.int8()
+    elif type(dt) == ShortType:
+        arrow_type = pa.int16()
+    elif type(dt) == IntegerType:
+        arrow_type = pa.int32()
+    elif type(dt) == LongType:
+        arrow_type = pa.int64()
+    elif type(dt) == FloatType:
+        arrow_type = pa.float32()
+    elif type(dt) == DoubleType:
+        arrow_type = pa.float64()
+    elif type(dt) == DecimalType:
+        arrow_type = pa.decimal(dt.precision, dt.scale)
+    elif type(dt) == StringType:
+        arrow_type = pa.string()
+    else:
+        raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
+    return arrow_type
+
+
 def _test():
     import doctest
     from pyspark.context import SparkContext

http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index baaa3fe..0e35cf7 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,7 +30,9 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
 from pyspark.taskcontext import TaskContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
-    write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer
+    write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \
+    BatchedSerializer, ArrowPandasSerializer
+from pyspark.sql.types import toArrowType
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -58,9 +60,12 @@ def read_command(serializer, file):
     return command
 
 
-def chain(f, g):
-    """chain two function together """
-    return lambda *a: g(f(*a))
+def chain(f, g, eval_type):
+    """chain two functions together """
+    if eval_type == PythonEvalType.SQL_PANDAS_UDF:
+        return lambda *a, **kwargs: g(f(*a, **kwargs), **kwargs)
+    else:
+        return lambda *a: g(f(*a))
 
 
 def wrap_udf(f, return_type):
@@ -71,7 +76,21 @@ def wrap_udf(f, return_type):
         return lambda *a: f(*a)
 
 
-def read_single_udf(pickleSer, infile):
+def wrap_pandas_udf(f, return_type):
+    arrow_return_type = toArrowType(return_type)
+
+    def verify_result_length(*a):
+        kwargs = a[-1]
+        result = f(*a[:-1], **kwargs)
+        if len(result) != kwargs["length"]:
+            raise RuntimeError("Result vector from pandas_udf was not the required length: "
+                               "expected %d, got %d\nUse input vector length or kwargs['length']"
+                               % (kwargs["length"], len(result)))
+        return result, arrow_return_type
+    return lambda *a: verify_result_length(*a)
+
+
+def read_single_udf(pickleSer, infile, eval_type):
     num_arg = read_int(infile)
     arg_offsets = [read_int(infile) for i in range(num_arg)]
     row_func = None
@@ -80,17 +99,22 @@ def read_single_udf(pickleSer, infile):
         if row_func is None:
             row_func = f
         else:
-            row_func = chain(row_func, f)
+            row_func = chain(row_func, f, eval_type)
     # the last returnType will be the return type of UDF
-    return arg_offsets, wrap_udf(row_func, return_type)
+    if eval_type == PythonEvalType.SQL_PANDAS_UDF:
+        # A pandas_udf will take kwargs as the last argument
+        arg_offsets = arg_offsets + [-1]
+        return arg_offsets, wrap_pandas_udf(row_func, return_type)
+    else:
+        return arg_offsets, wrap_udf(row_func, return_type)
 
 
-def read_udfs(pickleSer, infile):
+def read_udfs(pickleSer, infile, eval_type):
     num_udfs = read_int(infile)
     udfs = {}
     call_udf = []
     for i in range(num_udfs):
-        arg_offsets, udf = read_single_udf(pickleSer, infile)
+        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
         udfs['f%d' % i] = udf
         args = ["a[%d]" % o for o in arg_offsets]
         call_udf.append("f%d(%s)" % (i, ", ".join(args)))
@@ -102,7 +126,12 @@ def read_udfs(pickleSer, infile):
     mapper = eval(mapper_str, udfs)
 
     func = lambda _, it: map(mapper, it)
-    ser = BatchedSerializer(PickleSerializer(), 100)
+
+    if eval_type == PythonEvalType.SQL_PANDAS_UDF:
+        ser = ArrowPandasSerializer()
+    else:
+        ser = BatchedSerializer(PickleSerializer(), 100)
+
     # profiling is not supported for UDF
     return func, None, ser, ser
 
@@ -159,11 +188,11 @@ def main(infile, outfile):
                 _broadcastRegistry.pop(bid)
 
         _accumulatorRegistry.clear()
-        is_sql_udf = read_int(infile)
-        if is_sql_udf:
-            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile)
-        else:
+        eval_type = read_int(infile)
+        if eval_type == PythonEvalType.NON_UDF:
             func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+        else:
+            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
 
         init_time = time.time()
 

http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
new file mode 100644
index 0000000..f8bdc1e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A physical plan that evaluates a [[PythonUDF]],
+ */
+case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
+  extends EvalPythonExec(udfs, output, child) {
+
+  protected override def evaluate(
+      funcs: Seq[ChainedPythonFunctions],
+      bufferSize: Int,
+      reuseWorker: Boolean,
+      argOffsets: Array[Array[Int]],
+      iter: Iterator[InternalRow],
+      schema: StructType,
+      context: TaskContext): Iterator[InternalRow] = {
+    val inputIterator = ArrowConverters.toPayloadIterator(
+      iter, schema, conf.arrowMaxRecordsPerBatch, context).map(_.asPythonSerializable)
+
+    // Output iterator for results from Python.
+    val outputIterator = new PythonRunner(
+        funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets)
+      .compute(inputIterator, context.partitionId(), context)
+
+    val outputRowIterator = ArrowConverters.fromPayloadIterator(
+      outputIterator.map(new ArrowPayload(_)), context)
+
+    // Verify that the output schema is correct
+    val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
+      .map { case (attr, i) => attr.withName(s"_$i") })
+    assert(schemaOut.equals(outputRowIterator.schema),
+      s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")
+
+    outputRowIterator
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 3e176e2..2978eac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -17,153 +17,78 @@
 
 package org.apache.spark.sql.execution.python
 
-import java.io.File
-
 import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
 
 import net.razorvine.pickle.{Pickler, Unpickler}
 
-import org.apache.spark.{SparkEnv, TaskContext}
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
-import org.apache.spark.util.Utils
-
+import org.apache.spark.sql.types.{StructField, StructType}
 
 /**
- * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time.
- *
- * Python evaluation works by sending the necessary (projected) input data via a socket to an
- * external Python process, and combine the result from the Python process with the original row.
- *
- * For each row we send to Python, we also put it in a queue first. For each output row from Python,
- * we drain the queue to find the original input row. Note that if the Python process is way too
- * slow, this could lead to the queue growing unbounded and spill into disk when run out of memory.
- *
- * Here is a diagram to show how this works:
- *
- *            Downstream (for parent)
- *             /      \
- *            /     socket  (output of UDF)
- *           /         \
- *        RowQueue    Python
- *           \         /
- *            \     socket  (input of UDF)
- *             \     /
- *          upstream (from child)
- *
- * The rows sent to and received from Python are packed into batches (100 rows) and serialized,
- * there should be always some rows buffered in the socket or Python process, so the pulling from
- * RowQueue ALWAYS happened after pushing into it.
+ * A physical plan that evaluates a [[PythonUDF]] by sending pickled batches of rows to Python.
  */
 case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
-  extends SparkPlan {
-
-  def children: Seq[SparkPlan] = child :: Nil
-
-  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
-
-  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
-    udf.children match {
-      case Seq(u: PythonUDF) =>
-        val (chained, children) = collectFunctions(u)
-        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
-      case children =>
-        // There should not be any other UDFs, or the children can't be evaluated directly.
-        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
-        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
-    }
-  }
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val inputRDD = child.execute().map(_.copy())
-    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
-    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
-
-    inputRDD.mapPartitions { iter =>
-      EvaluatePython.registerPicklers()  // register pickler for Row
-
-      // The queue used to buffer input rows so we can drain it to
-      // combine input with output from Python.
-      val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(),
-        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
-      TaskContext.get().addTaskCompletionListener({ ctx =>
-        queue.close()
-      })
-
-      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
-
-      // flatten all the arguments
-      val allInputs = new ArrayBuffer[Expression]
-      val dataTypes = new ArrayBuffer[DataType]
-      val argOffsets = inputs.map { input =>
-        input.map { e =>
-          if (allInputs.exists(_.semanticEquals(e))) {
-            allInputs.indexWhere(_.semanticEquals(e))
-          } else {
-            allInputs += e
-            dataTypes += e.dataType
-            allInputs.length - 1
-          }
-        }.toArray
-      }.toArray
-      val projection = newMutableProjection(allInputs, child.output)
-      val schema = StructType(dataTypes.map(dt => StructField("", dt)))
-      val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
-
-      // enable memo iff we serialize the row with schema (schema and class should be memorized)
-      val pickle = new Pickler(needConversion)
-      // Input iterator to Python: input rows are grouped so we send them in batches to Python.
-      // For each row, add it to the queue.
-      val inputIterator = iter.map { inputRow =>
-        queue.add(inputRow.asInstanceOf[UnsafeRow])
-        val row = projection(inputRow)
-        if (needConversion) {
-          EvaluatePython.toJava(row, schema)
-        } else {
-          // fast path for these types that does not need conversion in Python
-          val fields = new Array[Any](row.numFields)
-          var i = 0
-          while (i < row.numFields) {
-            val dt = dataTypes(i)
-            fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
-            i += 1
-          }
-          fields
-        }
-      }.grouped(100).map(x => pickle.dumps(x.toArray))
-
-      val context = TaskContext.get()
-
-      // Output iterator for results from Python.
-      val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets)
-        .compute(inputIterator, context.partitionId(), context)
-
-      val unpickle = new Unpickler
-      val mutableRow = new GenericInternalRow(1)
-      val joined = new JoinedRow
-      val resultType = if (udfs.length == 1) {
-        udfs.head.dataType
+  extends EvalPythonExec(udfs, output, child) {
+
+  protected override def evaluate(
+      funcs: Seq[ChainedPythonFunctions],
+      bufferSize: Int,
+      reuseWorker: Boolean,
+      argOffsets: Array[Array[Int]],
+      iter: Iterator[InternalRow],
+      schema: StructType,
+      context: TaskContext): Iterator[InternalRow] = {
+    EvaluatePython.registerPicklers()  // register pickler for Row
+
+    val dataTypes = schema.map(_.dataType)
+    val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)
+
+    // enable memo iff we serialize the row with schema (schema and class should be memorized)
+    val pickle = new Pickler(needConversion)
+    // Input iterator to Python: input rows are grouped so we send them in batches to Python.
+    // The rows were already added to the join queue by EvalPythonExec.doExecute.
+    val inputIterator = iter.map { row =>
+      if (needConversion) {
+        EvaluatePython.toJava(row, schema)
       } else {
-        StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
-      }
-      val resultProj = UnsafeProjection.create(output, output)
-      outputIterator.flatMap { pickedResult =>
-        val unpickledBatch = unpickle.loads(pickedResult)
-        unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
-      }.map { result =>
-        val row = if (udfs.length == 1) {
-          // fast path for single UDF
-          mutableRow(0) = EvaluatePython.fromJava(result, resultType)
-          mutableRow
-        } else {
-          EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
+        // fast path for these types that does not need conversion in Python
+        val fields = new Array[Any](row.numFields)
+        var i = 0
+        while (i < row.numFields) {
+          val dt = dataTypes(i)
+          fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
+          i += 1
         }
-        resultProj(joined(queue.remove(), row))
+        fields
+      }
+    }.grouped(100).map(x => pickle.dumps(x.toArray))
+
+    // Output iterator for results from Python.
+    val outputIterator = new PythonRunner(
+        funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
+      .compute(inputIterator, context.partitionId(), context)
+
+    val unpickle = new Unpickler
+    val mutableRow = new GenericInternalRow(1)
+    val resultType = if (udfs.length == 1) {
+      udfs.head.dataType
+    } else {
+      StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
+    }
+    outputIterator.flatMap { pickedResult =>
+      val unpickledBatch = unpickle.loads(pickedResult)
+      unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
+    }.map { result =>
+      if (udfs.length == 1) {
+        // fast path for single UDF
+        mutableRow(0) = EvaluatePython.fromJava(result, resultType)
+        mutableRow
+      } else {
+        EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
new file mode 100644
index 0000000..860dc78
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.File
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+
+/**
+ * A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time.
+ *
+ * Python evaluation works by sending the necessary (projected) input data via a socket to an
+ * external Python process, and combining the result from the Python process with the original row.
+ *
+ * For each row we send to Python, we also put it in a queue first. For each output row from Python,
+ * we drain the queue to find the original input row. Note that if the Python process is way too
+ * slow, this could lead to the queue growing unbounded and spilling to disk when out of memory.
+ *
+ * Here is a diagram to show how this works:
+ *
+ *            Downstream (for parent)
+ *             /      \
+ *            /     socket  (output of UDF)
+ *           /         \
+ *        RowQueue    Python
+ *           \         /
+ *            \     socket  (input of UDF)
+ *             \     /
+ *          upstream (from child)
+ *
+ * The rows sent to and received from Python are packed into batches (100 rows when pickled) and
+ * serialized, so there should always be some rows buffered in the socket or Python process;
+ * pulling from the RowQueue therefore ALWAYS happens after pushing into it.
+ */
+abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
+  extends SparkPlan {
+
+  def children: Seq[SparkPlan] = child :: Nil
+
+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+    udf.children match {
+      case Seq(u: PythonUDF) =>
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+      case children =>
+        // There should not be any other UDFs, or the children can't be evaluated directly.
+        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+    }
+  }
+
+  protected def evaluate(
+      funcs: Seq[ChainedPythonFunctions],
+      bufferSize: Int,
+      reuseWorker: Boolean,
+      argOffsets: Array[Array[Int]],
+      iter: Iterator[InternalRow],
+      schema: StructType,
+      context: TaskContext): Iterator[InternalRow]
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val inputRDD = child.execute().map(_.copy())
+    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
+    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
+
+    inputRDD.mapPartitions { iter =>
+      val context = TaskContext.get()
+
+      // The queue used to buffer input rows so we can drain it to
+      // combine input with output from Python.
+      val queue = HybridRowQueue(context.taskMemoryManager(),
+        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
+      context.addTaskCompletionListener { ctx =>
+        queue.close()
+      }
+
+      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
+
+      // flatten all the arguments
+      val allInputs = new ArrayBuffer[Expression]
+      val dataTypes = new ArrayBuffer[DataType]
+      val argOffsets = inputs.map { input =>
+        input.map { e =>
+          if (allInputs.exists(_.semanticEquals(e))) {
+            allInputs.indexWhere(_.semanticEquals(e))
+          } else {
+            allInputs += e
+            dataTypes += e.dataType
+            allInputs.length - 1
+          }
+        }.toArray
+      }.toArray
+      val projection = newMutableProjection(allInputs, child.output)
+      val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
+        StructField(s"_$i", dt)
+      })
+
+      // Add rows to queue to join later with the result.
+      val projectedRowIter = iter.map { inputRow =>
+        queue.add(inputRow.asInstanceOf[UnsafeRow])
+        projection(inputRow)
+      }
+
+      val outputRowIterator = evaluate(
+        pyFuncs, bufferSize, reuseWorker, argOffsets, projectedRowIter, schema, context)
+
+      val joined = new JoinedRow
+      val resultProj = UnsafeProjection.create(output, output)
+
+      outputRowIterator.map { outputRow =>
+        resultProj(joined(queue.remove(), outputRow))
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 882a5ce..fec456d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -138,7 +138,16 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
           val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
             AttributeReference(s"pythonUDF$i", u.dataType)()
           }
-          val evaluation = BatchEvalPythonExec(validUdfs, child.output ++ resultAttrs, child)
+
+          val evaluation = validUdfs.partition(_.vectorized) match {
+            case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
+              ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
+            case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
+              BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
+            case _ =>
+              throw new IllegalArgumentException("Can not mix vectorized and non-vectorized UDFs")
+          }
+
           attributeMap ++= validUdfs.zip(resultAttrs)
           evaluation
         } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 7ebbdb9..84a6d9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -28,7 +28,8 @@ case class PythonUDF(
     name: String,
     func: PythonFunction,
     dataType: DataType,
-    children: Seq[Expression])
+    children: Seq[Expression],
+    vectorized: Boolean)
   extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {
 
   override def toString: String = s"$name(${children.mkString(", ")})"

http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 0d39c8f..a30a80a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -28,10 +28,11 @@ import org.apache.spark.sql.types.DataType
 case class UserDefinedPythonFunction(
     name: String,
     func: PythonFunction,
-    dataType: DataType) {
+    dataType: DataType,
+    vectorized: Boolean) {
 
   def builder(e: Seq[Expression]): PythonUDF = {
-    PythonUDF(name, func, dataType, e)
+    PythonUDF(name, func, dataType, e, vectorized)
   }
 
   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */

http://git-wip-us.apache.org/repos/asf/spark/blob/27fc536d/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index bbd9484..153e6e1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -105,5 +105,8 @@ class DummyUDF extends PythonFunction(
   broadcastVars = null,
   accumulator = null)
 
-class MyDummyPythonUDF
-  extends UserDefinedPythonFunction(name = "dummyUDF", func = new DummyUDF, dataType = BooleanType)
+class MyDummyPythonUDF extends UserDefinedPythonFunction(
+  name = "dummyUDF",
+  func = new DummyUDF,
+  dataType = BooleanType,
+  vectorized = false)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org