You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2017/10/10 22:32:07 UTC

spark git commit: [SPARK-20396][SQL][PYSPARK] groupby().apply() with pandas udf

Repository: spark
Updated Branches:
  refs/heads/master 2028e5a82 -> bfc7e1fe1


[SPARK-20396][SQL][PYSPARK] groupby().apply() with pandas udf

## What changes were proposed in this pull request?

This PR adds an apply() function on df.groupby(). apply() takes a pandas udf that is a transformation on `pandas.DataFrame` -> `pandas.DataFrame`.

Static schema
-------------------
```
schema = df.schema

@pandas_udf(schema)
def normalize(df):
    df = df.assign(v1 = (df.v1 - df.v1.mean()) / df.v1.std())
    return df

df.groupBy('id').apply(normalize)
```
Dynamic schema
-----------------------
**This use case is removed from the PR and we will discuss this as a follow up. See discussion https://github.com/apache/spark/pull/18732#pullrequestreview-66583248**

Another example to use pd.DataFrame dtypes as output schema of the udf:

```
sample_df = df.filter(df.id == 1).toPandas()

def foo(df):
      ret = # Some transformation on the input pd.DataFrame
      return ret

foo_udf = pandas_udf(foo, foo(sample_df).dtypes)

df.groupBy('id').apply(foo_udf)
```
In the interactive use case, users usually have a sample pd.DataFrame to test the function `foo` in their notebook. Being able to use `foo(sample_df).dtypes` frees the user from specifying the output schema of `foo`.

Design doc: https://github.com/icexelloss/spark/blob/pandas-udf-doc/docs/pyspark-pandas-udf.md

## How was this patch tested?
* Added GroupbyApplyTests

Author: Li Jin <ic...@gmail.com>
Author: Takuya UESHIN <ue...@databricks.com>
Author: Bryan Cutler <cu...@gmail.com>

Closes #18732 from icexelloss/groupby-apply-SPARK-20396.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bfc7e1fe
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bfc7e1fe
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bfc7e1fe

Branch: refs/heads/master
Commit: bfc7e1fe1ad5f9777126f2941e29bbe51ea5da7c
Parents: 2028e5a
Author: Li Jin <ic...@gmail.com>
Authored: Wed Oct 11 07:32:01 2017 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Wed Oct 11 07:32:01 2017 +0900

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 |   6 +-
 python/pyspark/sql/functions.py                 |  98 +++++++++---
 python/pyspark/sql/group.py                     |  88 ++++++++++-
 python/pyspark/sql/tests.py                     | 157 ++++++++++++++++++-
 python/pyspark/sql/types.py                     |   2 +-
 python/pyspark/worker.py                        |  35 +++--
 .../sql/catalyst/optimizer/Optimizer.scala      |   2 +
 .../plans/logical/pythonLogicalOperators.scala  |  39 +++++
 .../spark/sql/RelationalGroupedDataset.scala    |  36 ++++-
 .../spark/sql/execution/SparkStrategies.scala   |   2 +
 .../execution/python/ArrowEvalPythonExec.scala  |  39 ++++-
 .../execution/python/ArrowPythonRunner.scala    |  15 +-
 .../execution/python/ExtractPythonUDFs.scala    |   8 +-
 .../python/FlatMapGroupsInPandasExec.scala      | 103 ++++++++++++
 14 files changed, 561 insertions(+), 69 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index fe69e58..2d59622 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1227,7 +1227,7 @@ class DataFrame(object):
         """
         jgd = self._jdf.groupBy(self._jcols(*cols))
         from pyspark.sql.group import GroupedData
-        return GroupedData(jgd, self.sql_ctx)
+        return GroupedData(jgd, self)
 
     @since(1.4)
     def rollup(self, *cols):
@@ -1248,7 +1248,7 @@ class DataFrame(object):
         """
         jgd = self._jdf.rollup(self._jcols(*cols))
         from pyspark.sql.group import GroupedData
-        return GroupedData(jgd, self.sql_ctx)
+        return GroupedData(jgd, self)
 
     @since(1.4)
     def cube(self, *cols):
@@ -1271,7 +1271,7 @@ class DataFrame(object):
         """
         jgd = self._jdf.cube(self._jcols(*cols))
         from pyspark.sql.group import GroupedData
-        return GroupedData(jgd, self.sql_ctx)
+        return GroupedData(jgd, self)
 
     @since(1.3)
     def agg(self, *exprs):

http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b45a59d..9bc12c3 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2058,7 +2058,7 @@ class UserDefinedFunction(object):
         self._name = name or (
             func.__name__ if hasattr(func, '__name__')
             else func.__class__.__name__)
-        self._vectorized = vectorized
+        self.vectorized = vectorized
 
     @property
     def returnType(self):
@@ -2090,7 +2090,7 @@ class UserDefinedFunction(object):
         wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
-            self._name, wrapped_func, jdt, self._vectorized)
+            self._name, wrapped_func, jdt, self.vectorized)
         return judf
 
     def __call__(self, *cols):
@@ -2118,8 +2118,10 @@ class UserDefinedFunction(object):
         wrapper.__name__ = self._name
         wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
                               else self.func.__class__.__module__)
+
         wrapper.func = self.func
         wrapper.returnType = self.returnType
+        wrapper.vectorized = self.vectorized
 
         return wrapper
 
@@ -2129,8 +2131,12 @@ def _create_udf(f, returnType, vectorized):
     def _udf(f, returnType=StringType(), vectorized=vectorized):
         if vectorized:
             import inspect
-            if len(inspect.getargspec(f).args) == 0:
-                raise NotImplementedError("0-parameter pandas_udfs are not currently supported")
+            argspec = inspect.getargspec(f)
+            if len(argspec.args) == 0 and argspec.varargs is None:
+                raise ValueError(
+                    "0-arg pandas_udfs are not supported. "
+                    "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
+                )
         udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
         return udf_obj._wrapped()
 
@@ -2146,7 +2152,7 @@ def _create_udf(f, returnType, vectorized):
 
 @since(1.3)
 def udf(f=None, returnType=StringType()):
-    """Creates a :class:`Column` expression representing a user defined function (UDF).
+    """Creates a user defined function (UDF).
 
     .. note:: The user-defined functions must be deterministic. Due to optimization,
         duplicate invocations may be eliminated or the function may even be invoked more times than
@@ -2181,30 +2187,70 @@ def udf(f=None, returnType=StringType()):
 @since(2.3)
 def pandas_udf(f=None, returnType=StringType()):
     """
-    Creates a :class:`Column` expression representing a user defined function (UDF) that accepts
-    `Pandas.Series` as input arguments and outputs a `Pandas.Series` of the same length.
+    Creates a vectorized user defined function (UDF).
 
-    :param f: python function if used as a standalone function
+    :param f: user-defined function. A python function if used as a standalone function
     :param returnType: a :class:`pyspark.sql.types.DataType` object
 
-    >>> from pyspark.sql.types import IntegerType, StringType
-    >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
-    >>> @pandas_udf(returnType=StringType())
-    ... def to_upper(s):
-    ...     return s.str.upper()
-    ...
-    >>> @pandas_udf(returnType="integer")
-    ... def add_one(x):
-    ...     return x + 1
-    ...
-    >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
-    >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
-    ...     .show()  # doctest: +SKIP
-    +----------+--------------+------------+
-    |slen(name)|to_upper(name)|add_one(age)|
-    +----------+--------------+------------+
-    |         8|      JOHN DOE|          22|
-    +----------+--------------+------------+
+    The user-defined function can define one of the following transformations:
+
+    1. One or more `pandas.Series` -> A `pandas.Series`
+
+       This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and
+       :meth:`pyspark.sql.DataFrame.select`.
+       The returnType should be a primitive data type, e.g., `DoubleType()`.
+       The length of the returned `pandas.Series` must be the same as that of the input.
+
+       >>> from pyspark.sql.types import IntegerType, StringType
+       >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType())
+       >>> @pandas_udf(returnType=StringType())
+       ... def to_upper(s):
+       ...     return s.str.upper()
+       ...
+       >>> @pandas_udf(returnType="integer")
+       ... def add_one(x):
+       ...     return x + 1
+       ...
+       >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age"))
+       >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\
+       ...     .show()  # doctest: +SKIP
+       +----------+--------------+------------+
+       |slen(name)|to_upper(name)|add_one(age)|
+       +----------+--------------+------------+
+       |         8|      JOHN DOE|          22|
+       +----------+--------------+------------+
+
+    2. A `pandas.DataFrame` -> A `pandas.DataFrame`
+
+       This udf is only used with :meth:`pyspark.sql.GroupedData.apply`.
+       The returnType should be a :class:`StructType` describing the schema of the returned
+       `pandas.DataFrame`.
+
+       >>> df = spark.createDataFrame(
+       ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+       ...     ("id", "v"))
+       >>> @pandas_udf(returnType=df.schema)
+       ... def normalize(pdf):
+       ...     v = pdf.v
+       ...     return pdf.assign(v=(v - v.mean()) / v.std())
+       >>> df.groupby('id').apply(normalize).show()  # doctest: +SKIP
+       +---+-------------------+
+       | id|                  v|
+       +---+-------------------+
+       |  1|-0.7071067811865475|
+       |  1| 0.7071067811865475|
+       |  2|-0.8320502943378437|
+       |  2|-0.2773500981126146|
+       |  2| 1.1094003924504583|
+       +---+-------------------+
+
+       .. note:: This type of udf cannot be used with functions such as `withColumn` or `select`
+                 because it defines a `DataFrame` transformation rather than a `Column`
+                 transformation.
+
+       .. seealso:: :meth:`pyspark.sql.GroupedData.apply`
+
+    .. note:: The user-defined function must be deterministic.
     """
     return _create_udf(f, returnType=returnType, vectorized=True)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/python/pyspark/sql/group.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index f2092f9..817d0bc 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -54,9 +54,10 @@ class GroupedData(object):
     .. versionadded:: 1.3
     """
 
-    def __init__(self, jgd, sql_ctx):
+    def __init__(self, jgd, df):
         self._jgd = jgd
-        self.sql_ctx = sql_ctx
+        self._df = df
+        self.sql_ctx = df.sql_ctx
 
     @ignore_unicode_prefix
     @since(1.3)
@@ -170,7 +171,7 @@ class GroupedData(object):
     @since(1.6)
     def pivot(self, pivot_col, values=None):
         """
-        Pivots a column of the current [[DataFrame]] and perform the specified aggregation.
+        Pivots a column of the current :class:`DataFrame` and perform the specified aggregation.
         There are two versions of pivot function: one that requires the caller to specify the list
         of distinct values to pivot on, and one that does not. The latter is more concise but less
         efficient, because Spark needs to first compute the list of distinct values internally.
@@ -192,7 +193,85 @@ class GroupedData(object):
             jgd = self._jgd.pivot(pivot_col)
         else:
             jgd = self._jgd.pivot(pivot_col, values)
-        return GroupedData(jgd, self.sql_ctx)
+        return GroupedData(jgd, self._df)
+
+    @since(2.3)
+    def apply(self, udf):
+        """
+        Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result
+        as a `DataFrame`.
+
+        The user-defined function should take a `pandas.DataFrame` and return another
+        `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame`
+        to the user-function and the returned `pandas.DataFrame`s are combined as a
+        :class:`DataFrame`.
+        The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
+        returnType of the pandas udf.
+
+        This function does not support partial aggregation, and requires shuffling all the data in
+        the :class:`DataFrame`.
+
+        :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf`
+
+        >>> from pyspark.sql.functions import pandas_udf
+        >>> df = spark.createDataFrame(
+        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+        ...     ("id", "v"))
+        >>> @pandas_udf(returnType=df.schema)
+        ... def normalize(pdf):
+        ...     v = pdf.v
+        ...     return pdf.assign(v=(v - v.mean()) / v.std())
+        >>> df.groupby('id').apply(normalize).show()  # doctest: +SKIP
+        +---+-------------------+
+        | id|                  v|
+        +---+-------------------+
+        |  1|-0.7071067811865475|
+        |  1| 0.7071067811865475|
+        |  2|-0.8320502943378437|
+        |  2|-0.2773500981126146|
+        |  2| 1.1094003924504583|
+        +---+-------------------+
+
+        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
+
+        """
+        from pyspark.sql.functions import pandas_udf
+
+        # Columns are special because hasattr always returns True for them
+        if isinstance(udf, Column) or not hasattr(udf, 'func') or not udf.vectorized:
+            raise ValueError("The argument to apply must be a pandas_udf")
+        if not isinstance(udf.returnType, StructType):
+            raise ValueError("The returnType of the pandas_udf must be a StructType")
+
+        df = self._df
+        func = udf.func
+        returnType = udf.returnType
+
+        # The python executors expect the function to use pd.Series as input and output,
+        # so we create a wrapper function that turns them into a pd.DataFrame before passing
+        # it down to the user function, then turns the result pd.DataFrame back into pd.Series
+        columns = df.columns
+
+        def wrapped(*cols):
+            from pyspark.sql.types import to_arrow_type
+            import pandas as pd
+            result = func(pd.concat(cols, axis=1, keys=columns))
+            if not isinstance(result, pd.DataFrame):
+                raise TypeError("Return type of the user-defined function should be "
+                                "Pandas.DataFrame, but is {}".format(type(result)))
+            if not len(result.columns) == len(returnType):
+                raise RuntimeError(
+                    "Number of columns of the returned Pandas.DataFrame "
+                    "doesn't match specified schema. "
+                    "Expected: {} Actual: {}".format(len(returnType), len(result.columns)))
+            arrow_return_types = (to_arrow_type(field.dataType) for field in returnType)
+            return [(result[result.columns[i]], arrow_type)
+                    for i, arrow_type in enumerate(arrow_return_types)]
+
+        wrapped_udf_obj = pandas_udf(wrapped, returnType)
+        udf_column = wrapped_udf_obj(*[df[col] for col in df.columns])
+        jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
+        return DataFrame(jdf, self.sql_ctx)
 
 
 def _test():
@@ -206,6 +285,7 @@ def _test():
         .getOrCreate()
     sc = spark.sparkContext
     globs['sc'] = sc
+    globs['spark'] = spark
     globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
         .toDF(StructType([StructField('age', IntegerType()),
                           StructField('name', StringType())]))

http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a59378b..bac2ef8 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3256,17 +3256,17 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
 
     def test_vectorized_udf_zero_parameter(self):
         from pyspark.sql.functions import pandas_udf
-        error_str = '0-parameter pandas_udfs.*not.*supported'
+        error_str = '0-arg pandas_udfs.*not.*supported'
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(NotImplementedError, error_str):
+            with self.assertRaisesRegexp(ValueError, error_str):
                 pandas_udf(lambda: 1, LongType())
 
-            with self.assertRaisesRegexp(NotImplementedError, error_str):
+            with self.assertRaisesRegexp(ValueError, error_str):
                 @pandas_udf
                 def zero_no_type():
                     return 1
 
-            with self.assertRaisesRegexp(NotImplementedError, error_str):
+            with self.assertRaisesRegexp(ValueError, error_str):
                 @pandas_udf(LongType())
                 def zero_with_type():
                     return 1
@@ -3348,7 +3348,7 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
         df = self.spark.range(10)
         f = pandas_udf(lambda x: x * 1.0, StringType())
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, 'Invalid.*type.*string'):
+            with self.assertRaisesRegexp(Exception, 'Invalid.*type'):
                 df.select(f(col('id'))).collect()
 
     def test_vectorized_udf_return_scalar(self):
@@ -3356,7 +3356,7 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
         df = self.spark.range(10)
         f = pandas_udf(lambda x: 1.0, DoubleType())
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, 'Return.*type.*pandas_udf.*Series'):
+            with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'):
                 df.select(f(col('id'))).collect()
 
     def test_vectorized_udf_decorator(self):
@@ -3376,6 +3376,151 @@ class VectorizedUDFTests(ReusedPySparkTestCase):
         res = df.select(f(col('id')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_varargs(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
+        f = pandas_udf(lambda *v: v[0], LongType())
+        res = df.select(f(col('id')))
+        self.assertEquals(df.collect(), res.collect())
+
+
+@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+class GroupbyApplyTests(ReusedPySparkTestCase):
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        cls.spark.stop()
+
+    def assertFramesEqual(self, expected, result):
+        msg = ("DataFrames are not equal: " +
+               ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) +
+               ("\n\nResult:\n%s\n%s" % (result, result.dtypes)))
+        self.assertTrue(expected.equals(result), msg=msg)
+
+    @property
+    def data(self):
+        from pyspark.sql.functions import array, explode, col, lit
+        return self.spark.range(10).toDF('id') \
+            .withColumn("vs", array([lit(i) for i in range(20, 30)])) \
+            .withColumn("v", explode(col('vs'))).drop('vs')
+
+    def test_simple(self):
+        from pyspark.sql.functions import pandas_udf
+        df = self.data
+
+        foo_udf = pandas_udf(
+            lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            StructType(
+                [StructField('id', LongType()),
+                 StructField('v', IntegerType()),
+                 StructField('v1', DoubleType()),
+                 StructField('v2', LongType())]))
+
+        result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
+        expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
+        self.assertFramesEqual(expected, result)
+
+    def test_decorator(self):
+        from pyspark.sql.functions import pandas_udf
+        df = self.data
+
+        @pandas_udf(StructType(
+            [StructField('id', LongType()),
+             StructField('v', IntegerType()),
+             StructField('v1', DoubleType()),
+             StructField('v2', LongType())]))
+        def foo(pdf):
+            return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id)
+
+        result = df.groupby('id').apply(foo).sort('id').toPandas()
+        expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
+        self.assertFramesEqual(expected, result)
+
+    def test_coerce(self):
+        from pyspark.sql.functions import pandas_udf
+        df = self.data
+
+        foo = pandas_udf(
+            lambda pdf: pdf,
+            StructType([StructField('id', LongType()), StructField('v', DoubleType())]))
+
+        result = df.groupby('id').apply(foo).sort('id').toPandas()
+        expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
+        expected = expected.assign(v=expected.v.astype('float64'))
+        self.assertFramesEqual(expected, result)
+
+    def test_complex_groupby(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.data
+
+        @pandas_udf(StructType(
+            [StructField('id', LongType()),
+             StructField('v', IntegerType()),
+             StructField('norm', DoubleType())]))
+        def normalize(pdf):
+            v = pdf.v
+            return pdf.assign(norm=(v - v.mean()) / v.std())
+
+        result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas()
+        pdf = df.toPandas()
+        expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func)
+        expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
+        expected = expected.assign(norm=expected.norm.astype('float64'))
+        self.assertFramesEqual(expected, result)
+
+    def test_empty_groupby(self):
+        from pyspark.sql.functions import pandas_udf, col
+        df = self.data
+
+        @pandas_udf(StructType(
+            [StructField('id', LongType()),
+             StructField('v', IntegerType()),
+             StructField('norm', DoubleType())]))
+        def normalize(pdf):
+            v = pdf.v
+            return pdf.assign(norm=(v - v.mean()) / v.std())
+
+        result = df.groupby().apply(normalize).sort('id', 'v').toPandas()
+        pdf = df.toPandas()
+        expected = normalize.func(pdf)
+        expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
+        expected = expected.assign(norm=expected.norm.astype('float64'))
+        self.assertFramesEqual(expected, result)
+
+    def test_wrong_return_type(self):
+        from pyspark.sql.functions import pandas_udf
+        df = self.data
+
+        foo = pandas_udf(
+            lambda pdf: pdf,
+            StructType([StructField('id', LongType()), StructField('v', StringType())]))
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(Exception, 'Invalid.*type'):
+                df.groupby('id').apply(foo).sort('id').toPandas()
+
+    def test_wrong_args(self):
+        from pyspark.sql.functions import udf, pandas_udf, sum
+        df = self.data
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(ValueError, 'pandas_udf'):
+                df.groupby('id').apply(lambda x: x)
+            with self.assertRaisesRegexp(ValueError, 'pandas_udf'):
+                df.groupby('id').apply(udf(lambda x: x, DoubleType()))
+            with self.assertRaisesRegexp(ValueError, 'pandas_udf'):
+                df.groupby('id').apply(sum(df.v))
+            with self.assertRaisesRegexp(ValueError, 'pandas_udf'):
+                df.groupby('id').apply(df.v + 1)
+            with self.assertRaisesRegexp(ValueError, 'returnType'):
+                df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType()))
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests import *
     if xmlrunner:

http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index ebdc11c..f65273d 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1597,7 +1597,7 @@ register_input_converter(DatetimeConverter())
 register_input_converter(DateConverter())
 
 
-def toArrowType(dt):
+def to_arrow_type(dt):
     """ Convert Spark data type to pyarrow type
     """
     import pyarrow as pa

http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 4e24789..eb6d486 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -32,7 +32,7 @@ from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
-from pyspark.sql.types import toArrowType
+from pyspark.sql.types import to_arrow_type, StructType
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -74,17 +74,28 @@ def wrap_udf(f, return_type):
 
 
 def wrap_pandas_udf(f, return_type):
-    arrow_return_type = toArrowType(return_type)
-
-    def verify_result_length(*a):
-        result = f(*a)
-        if not hasattr(result, "__len__"):
-            raise TypeError("Return type of pandas_udf should be a Pandas.Series")
-        if len(result) != len(a[0]):
-            raise RuntimeError("Result vector from pandas_udf was not the required length: "
-                               "expected %d, got %d" % (len(a[0]), len(result)))
-        return result
-    return lambda *a: (verify_result_length(*a), arrow_return_type)
+    # If the return_type is a StructType, it indicates this is a groupby apply udf,
+    # and has already been wrapped under apply(), otherwise, it's a vectorized column udf.
+    # We can distinguish these two by return type because in groupby apply, we always specify
+    # returnType as a StructType, and in vectorized column udf, StructType is not supported.
+    #
+    # TODO: Look into refactoring use of StructType to be more flexible for future pandas_udfs
+    if isinstance(return_type, StructType):
+        return lambda *a: f(*a)
+    else:
+        arrow_return_type = to_arrow_type(return_type)
+
+        def verify_result_length(*a):
+            result = f(*a)
+            if not hasattr(result, "__len__"):
+                raise TypeError("Return type of the user-defined functon should be "
+                                "Pandas.Series, but is {}".format(type(result)))
+            if len(result) != len(a[0]):
+                raise RuntimeError("Result vector from pandas_udf was not the required length: "
+                                   "expected %d, got %d" % (len(a[0]), len(result)))
+            return result
+
+        return lambda *a: (verify_result_length(*a), arrow_return_type)
 
 
 def read_single_udf(pickleSer, infile, eval_type):

http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index bc2d4a8..d829e01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -452,6 +452,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
     // Prunes the unused columns from child of Aggregate/Expand/Generate
     case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
       a.copy(child = prunedChild(child, a.references))
+    case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty =>
+      f.copy(child = prunedChild(child, f.references))
     case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
       e.copy(child = prunedChild(child, e.references))
     case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>

http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
new file mode 100644
index 0000000..8abab24
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression}
+
+/**
+ * FlatMap groups using a udf: pandas.DataFrame -> pandas.DataFrame.
+ * This is used by DataFrame.groupby().apply().
+ */
+case class FlatMapGroupsInPandas(
+  groupingAttributes: Seq[Attribute],
+  functionExpr: Expression,
+  output: Seq[Attribute],
+  child: LogicalPlan) extends UnaryNode {
+  /**
+   * This is needed because output attributes are considered `references` when
+   * passed through the constructor.
+   *
+   * Without this, catalyst will complain that output attributes are missing
+   * from the input.
+   */
+  override val producedAttributes = AttributeSet(output)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 147b549..cd0ac1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -27,12 +27,12 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
+import org.apache.spark.sql.execution.python.PythonUDF
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.NumericType
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{NumericType, StructField, StructType}
 
 /**
  * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
@@ -435,6 +435,36 @@ class RelationalGroupedDataset protected[sql](
           df.logicalPlan.output,
           df.logicalPlan))
   }
+
+  /**
+   * Runs a vectorized python user-defined function over every group of this dataset.
+   * The function is a `pandas.DataFrame` -> `pandas.DataFrame` transformation: it receives
+   * all elements of one group as a single `pandas.DataFrame`, and the results produced for
+   * the individual groups are put together into a new [[DataFrame]].
+   *
+   * Partial aggregation is not supported by this function, so all the data in the
+   * [[DataFrame]] has to be shuffled.
+   *
+   * Apache Arrow is the serialization format between Java executors and Python workers.
+   */
+  private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = {
+    require(expr.vectorized, "Must pass a vectorized python udf")
+    require(expr.dataType.isInstanceOf[StructType],
+      "The returnType of the vectorized python udf must be a StructType")
+
+    // Grouping expressions without a name are aliased so they can be projected as columns.
+    val namedGroupingExprs: Seq[NamedExpression] = groupingExprs.map {
+      case named: NamedExpression => named
+      case unnamed => Alias(unnamed, unnamed.toString)()
+    }
+    val child = df.logicalPlan
+    // The grouping columns are placed in front of the original columns of the child plan.
+    val input = Project(namedGroupingExprs ++ child.output, child)
+    val resultAttributes = expr.dataType.asInstanceOf[StructType].toAttributes
+    val groupingAttributes = namedGroupingExprs.map(_.toAttribute)
+    Dataset.ofRows(
+      df.sparkSession, FlatMapGroupsInPandas(groupingAttributes, expr, resultAttributes, input))
+  }
 }
 
 private[sql] object RelationalGroupedDataset {

http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 92eaab5..4cdcc73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -392,6 +392,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) =>
         execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping,
           data, objAttr, planLater(child)) :: Nil
+      case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
+        execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
       case logical.MapElements(f, _, _, objAttr, child) =>
         execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil
       case logical.AppendColumns(f, _, _, in, out, child) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index f7e8cbe..8189618 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -27,6 +27,35 @@ import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.types.StructType
 
 /**
+ * Groups an iterator into batches of at most `batchSize` elements.
+ * This is similar to iter.grouped but returns Iterator[T] instead of Seq[T], because we
+ * sometimes cannot hold references to input rows: some of them are mutable and get reused.
+ * `batchSize` must be positive, otherwise a non-empty input would yield empty batches forever.
+ */
+private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
+  extends Iterator[Iterator[T]] {
+
+  require(batchSize > 0, s"batchSize must be positive, but got $batchSize")
+
+  override def hasNext: Boolean = iter.hasNext
+
+  override def next(): Iterator[T] = new Iterator[T] {
+    private var count = 0
+
+    override def hasNext: Boolean = iter.hasNext && count < batchSize
+
+    override def next(): T = {
+      if (!hasNext) {
+        Iterator.empty.next()
+      } else {
+        count += 1
+        iter.next()
+      }
+    }
+  }
+}
+
+/**
  * A physical plan that evaluates a [[PythonUDF]],
  */
 case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
@@ -44,14 +73,18 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi
     val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
       .map { case (attr, i) => attr.withName(s"_$i") })
 
+    val batchSize = conf.arrowMaxRecordsPerBatch
+    // DO NOT use iter.grouped(). See BatchIterator.
+    val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)
+
     val columnarBatchIter = new ArrowPythonRunner(
-        funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker,
+        funcs, bufferSize, reuseWorker,
         PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema)
-      .compute(iter, context.partitionId(), context)
+      .compute(batchIter, context.partitionId(), context)
 
     new Iterator[InternalRow] {
 
-      var currentIter = if (columnarBatchIter.hasNext) {
+      private var currentIter = if (columnarBatchIter.hasNext) {
         val batch = columnarBatchIter.next()
         assert(schemaOut.equals(batch.schema),
           s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}")

http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index bbad9d6..f6c03c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -39,19 +39,18 @@ import org.apache.spark.util.Utils
  */
 class ArrowPythonRunner(
     funcs: Seq[ChainedPythonFunctions],
-    batchSize: Int,
     bufferSize: Int,
     reuseWorker: Boolean,
     evalType: Int,
     argOffsets: Array[Array[Int]],
     schema: StructType)
-  extends BasePythonRunner[InternalRow, ColumnarBatch](
+  extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
     funcs, bufferSize, reuseWorker, evalType, argOffsets) {
 
   protected override def newWriterThread(
       env: SparkEnv,
       worker: Socket,
-      inputIterator: Iterator[InternalRow],
+      inputIterator: Iterator[Iterator[InternalRow]],
       partitionIndex: Int,
       context: TaskContext): WriterThread = {
     new WriterThread(env, worker, inputIterator, partitionIndex, context) {
@@ -82,12 +81,12 @@ class ArrowPythonRunner(
 
         Utils.tryWithSafeFinally {
           while (inputIterator.hasNext) {
-            var rowCount = 0
-            while (inputIterator.hasNext && (batchSize <= 0 || rowCount < batchSize)) {
-              val row = inputIterator.next()
-              arrowWriter.write(row)
-              rowCount += 1
+            val nextBatch = inputIterator.next()
+
+            while (nextBatch.hasNext) {
+              arrowWriter.write(nextBatch.next())
             }
+
             arrowWriter.finish()
             writer.writeBatch()
             arrowWriter.reset()

http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index fec456d..e3f952e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution
-import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
+import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 
 
 /**
@@ -111,6 +110,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan transformUp {
+    // FlatMapGroupsInPandas is evaluated directly in the python worker,
+    // therefore we don't need to extract the UDFs from it.
+    case plan: FlatMapGroupsInPandasExec => plan
     case plan: SparkPlan => extract(plan)
   }
 
@@ -169,7 +171,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
       val newPlan = extract(rewritten)
       if (newPlan.output != plan.output) {
         // Trim away the new UDF value if it was only used for filtering or something.
-        execution.ProjectExec(plan.output, newPlan)
+        ProjectExec(plan.output, newPlan)
       } else {
         newPlan
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/bfc7e1fe/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
new file mode 100644
index 0000000..b996b5b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]]
+ *
+ * Each group of rows is sent to the Python worker as one Arrow record batch. The
+ * worker converts that batch into a `pandas.DataFrame`, invokes the user-defined
+ * function on it and sends the resulting `pandas.DataFrame` back as an Arrow record
+ * batch. Every returned record batch is then exposed as an Iterator[InternalRow]
+ * through ColumnarBatch.
+ *
+ * Note on memory usage:
+ * The largest group has to fit in memory on both sides. The Java executor needs
+ * (off heap) memory to build the record batch and the Python worker needs memory to
+ * hold the `pandas.DataFrame`. Splitting a group into several record batches would
+ * lower the memory footprint on the Java side; this is left as future work.
+ */
+case class FlatMapGroupsInPandasExec(
+    groupingAttributes: Seq[Attribute],
+    func: Expression,
+    output: Seq[Attribute],
+    child: SparkPlan)
+  extends UnaryExecNode {
+
+  // The Python function wrapped by the udf expression; `func` must be a PythonUDF.
+  private val pandasFunction = func.asInstanceOf[PythonUDF].func
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def producedAttributes: AttributeSet = AttributeSet(output)
+
+  // All rows of a group must end up in the same partition.
+  override def requiredChildDistribution: Seq[Distribution] = {
+    val distribution: Distribution =
+      if (groupingAttributes.nonEmpty) ClusteredDistribution(groupingAttributes) else AllTuples
+    distribution :: Nil
+  }
+
+  // Rows are sorted by the grouping columns so that each group is contiguous.
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    groupingAttributes.map(attr => SortOrder(attr, Ascending)) :: Nil
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val inputRDD = child.execute()
+
+    val sparkConf = inputRDD.conf
+    val bufferSize = sparkConf.getInt("spark.buffer.size", 65536)
+    val reuseWorker = sparkConf.getBoolean("spark.python.worker.reuse", defaultValue = true)
+    val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
+    // The leading grouping columns of the child output are not passed to the udf.
+    val numDataColumns = child.output.length - groupingAttributes.length
+    val argOffsets = Array(Array.range(0, numDataColumns))
+    val schema = StructType(child.schema.drop(groupingAttributes.length))
+
+    inputRDD.mapPartitionsInternal { rows =>
+      val groups = if (groupingAttributes.nonEmpty) {
+        val groupedIter = GroupedIterator(rows, groupingAttributes, child.output)
+        val dataColumns = child.output.drop(groupingAttributes.length)
+        val dropGrouping = UnsafeProjection.create(dataColumns, child.output)
+        groupedIter.map { case (_, groupRows) => groupRows.map(dropGrouping) }
+      } else {
+        // Without grouping columns the whole partition forms a single group.
+        Iterator(rows)
+      }
+
+      val context = TaskContext.get()
+      val runner = new ArrowPythonRunner(
+        chainedFunc, bufferSize, reuseWorker,
+        PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema)
+      val columnarBatchIter = runner.compute(groups, context.partitionId(), context)
+      val toUnsafeRow = UnsafeProjection.create(output, output)
+      columnarBatchIter.flatMap(_.rowIterator.asScala).map(toUnsafeRow)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org