Posted to commits@spark.apache.org by ue...@apache.org on 2018/01/23 05:11:41 UTC

spark git commit: [SPARK-22274][PYTHON][SQL] User-defined aggregation functions with pandas udf (full shuffle)

Repository: spark
Updated Branches:
  refs/heads/master 51eb75026 -> b2ce17b4c


[SPARK-22274][PYTHON][SQL] User-defined aggregation functions with pandas udf (full shuffle)

## What changes were proposed in this pull request?

Add support for using pandas UDFs with groupby().agg().

This PR introduces a new type of pandas UDF: the group aggregate pandas UDF. This type of UDF defines a transformation from one or more pandas Series to a scalar value, and can be used with groupby().agg(). Note that group aggregate pandas UDFs do not support partial aggregation, i.e., a full shuffle is required.

This PR doesn't support group aggregate pandas UDFs that return ArrayType, StructType or MapType. Support for these types is left for a future PR.
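
As a quick illustration, here is a minimal usage sketch adapted from the doctest added in this patch (it assumes an active SparkSession bound to `spark`, plus pandas and pyarrow installed):

```python
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))

# A group aggregate pandas UDF receives one or more pandas.Series per group
# and returns a single scalar.
@pandas_udf("double", PandasUDFType.GROUP_AGG)
def mean_udf(v):
    return v.mean()

# Only usable inside groupby().agg(); there is no partial aggregation,
# so a full shuffle is performed.
df.groupby("id").agg(mean_udf(df["v"])).show()
# +---+-----------+
# | id|mean_udf(v)|
# +---+-----------+
# |  1|        1.5|
# |  2|        6.0|
# +---+-----------+
```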

## How was this patch tested?

Added GroupbyAggPandasUDFTests to python/pyspark/sql/tests.py.

Author: Li Jin <ic...@gmail.com>

Closes #19872 from icexelloss/SPARK-22274-groupby-agg.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b2ce17b4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b2ce17b4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b2ce17b4

Branch: refs/heads/master
Commit: b2ce17b4c9fea58140a57ca1846b2689b15c0d61
Parents: 51eb750
Author: Li Jin <ic...@gmail.com>
Authored: Tue Jan 23 14:11:30 2018 +0900
Committer: Takuya UESHIN <ue...@databricks.com>
Committed: Tue Jan 23 14:11:30 2018 +0900

----------------------------------------------------------------------
 .../apache/spark/api/python/PythonRunner.scala  |   2 +
 python/pyspark/rdd.py                           |   1 +
 python/pyspark/sql/functions.py                 |  36 +-
 python/pyspark/sql/group.py                     |  33 +-
 python/pyspark/sql/tests.py                     | 486 ++++++++++++++++++-
 python/pyspark/sql/udf.py                       |  13 +-
 python/pyspark/worker.py                        |  22 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  14 +-
 .../sql/catalyst/expressions/PythonUDF.scala    |  64 +++
 .../spark/sql/catalyst/planning/patterns.scala  |  12 +-
 .../spark/sql/RelationalGroupedDataset.scala    |   1 -
 .../spark/sql/execution/SparkStrategies.scala   |  29 +-
 .../python/AggregateInPandasExec.scala          | 155 ++++++
 .../execution/python/ExtractPythonUDFs.scala    |  16 +-
 .../spark/sql/execution/python/PythonUDF.scala  |  41 --
 .../python/UserDefinedPythonFunction.scala      |   2 +-
 16 files changed, 829 insertions(+), 98 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 1ec0e71..29148a7 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -39,12 +39,14 @@ private[spark] object PythonEvalType {
 
   val SQL_PANDAS_SCALAR_UDF = 200
   val SQL_PANDAS_GROUP_MAP_UDF = 201
+  val SQL_PANDAS_GROUP_AGG_UDF = 202
 
   def toString(pythonEvalType: Int): String = pythonEvalType match {
     case NON_UDF => "NON_UDF"
     case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
     case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF"
     case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF"
+    case SQL_PANDAS_GROUP_AGG_UDF => "SQL_PANDAS_GROUP_AGG_UDF"
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 1b39155..6b018c3 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -70,6 +70,7 @@ class PythonEvalType(object):
 
     SQL_PANDAS_SCALAR_UDF = 200
     SQL_PANDAS_GROUP_MAP_UDF = 201
+    SQL_PANDAS_GROUP_AGG_UDF = 202
 
 
 def portable_hash(x):

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 961b326..a291c9b 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2089,6 +2089,8 @@ class PandasUDFType(object):
 
     GROUP_MAP = PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF
 
+    GROUP_AGG = PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF
+
 
 @since(1.3)
 def udf(f=None, returnType=StringType()):
@@ -2159,7 +2161,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
     1. SCALAR
 
        A scalar UDF defines a transformation: One or more `pandas.Series` -> A `pandas.Series`.
-       The returnType should be a primitive data type, e.g., `DoubleType()`.
+       The returnType should be a primitive data type, e.g., :class:`DoubleType`.
        The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`.
 
        Scalar UDFs are used with :meth:`pyspark.sql.DataFrame.withColumn` and
@@ -2221,6 +2223,35 @@ def pandas_udf(f=None, returnType=None, functionType=None):
 
        .. seealso:: :meth:`pyspark.sql.GroupedData.apply`
 
+    3. GROUP_AGG
+
+       A group aggregate UDF defines a transformation: One or more `pandas.Series` -> A scalar.
+       The `returnType` should be a primitive data type, e.g., :class:`DoubleType`.
+       The returned scalar can be either a python primitive type, e.g., `int` or `float`
+       or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
+
+       :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as
+       output types.
+
+       Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg`
+
+       >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+       >>> df = spark.createDataFrame(
+       ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+       ...     ("id", "v"))
+       >>> @pandas_udf("double", PandasUDFType.GROUP_AGG)  # doctest: +SKIP
+       ... def mean_udf(v):
+       ...     return v.mean()
+       >>> df.groupby("id").agg(mean_udf(df['v'])).show()  # doctest: +SKIP
+       +---+-----------+
+       | id|mean_udf(v)|
+       +---+-----------+
+       |  1|        1.5|
+       |  2|        6.0|
+       +---+-----------+
+
+       .. seealso:: :meth:`pyspark.sql.GroupedData.agg`
+
     .. note:: The user-defined functions are considered deterministic by default. Due to
         optimization, duplicate invocations may be eliminated or the function may even be invoked
         more times than it is present in the query. If your function is not deterministic, call
@@ -2267,7 +2298,8 @@ def pandas_udf(f=None, returnType=None, functionType=None):
         raise ValueError("Invalid returnType: returnType can not be None")
 
     if eval_type not in [PythonEvalType.SQL_PANDAS_SCALAR_UDF,
-                         PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF]:
+                         PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF,
+                         PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF]:
         raise ValueError("Invalid functionType: "
                          "functionType must be one the values from PandasUDFType")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/python/pyspark/sql/group.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 22061b8..f90a909 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -65,13 +65,27 @@ class GroupedData(object):
     def agg(self, *exprs):
         """Compute aggregates and returns the result as a :class:`DataFrame`.
 
-        The available aggregate functions are `avg`, `max`, `min`, `sum`, `count`.
+        The available aggregate functions can be:
+
+        1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count`
+
+        2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf`
+
+           .. note:: There is no partial aggregation with group aggregate UDFs, i.e.,
+               a full shuffle is required. Also, all the data of a group will be loaded into
+               memory, so the user should be aware of the potential OOM risk if data is skewed
+               and certain groups are too large to fit in memory.
+
+           .. seealso:: :func:`pyspark.sql.functions.pandas_udf`
 
         If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
         is the column to perform aggregation on, and the value is the aggregate function.
 
         Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.
 
+        .. note:: Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed
+            in a single call to this function.
+
         :param exprs: a dict mapping from column name (string) to aggregate functions (string),
             or a list of :class:`Column`.
 
@@ -82,6 +96,13 @@ class GroupedData(object):
         >>> from pyspark.sql import functions as F
         >>> sorted(gdf.agg(F.min(df.age)).collect())
         [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]
+
+        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+        >>> @pandas_udf('int', PandasUDFType.GROUP_AGG)  # doctest: +SKIP
+        ... def min_udf(v):
+        ...     return v.min()
+        >>> sorted(gdf.agg(min_udf(df.age)).collect())  # doctest: +SKIP
+        [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)]
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
@@ -204,16 +225,18 @@ class GroupedData(object):
 
         The user-defined function should take a `pandas.DataFrame` and return another
         `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame`
-        to the user-function and the returned `pandas.DataFrame`s are combined as a
+        to the user-function and the returned `pandas.DataFrame` are combined as a
         :class:`DataFrame`.
+
         The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the
         returnType of the pandas udf.
 
-        This function does not support partial aggregation, and requires shuffling all the data in
-        the :class:`DataFrame`.
+        .. note:: This function requires a full shuffle. All the data of a group will be loaded
+            into memory, so the user should be aware of the potential OOM risk if data is skewed
+            and certain groups are too large to fit in memory.
 
         :param udf: a group map user-defined function returned by
-            :meth:`pyspark.sql.functions.pandas_udf`.
+            :func:`pyspark.sql.functions.pandas_udf`.
 
         >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
         >>> df = spark.createDataFrame(

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 4fee2ec..84e8eec 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -197,6 +197,12 @@ class ReusedSQLTestCase(ReusedPySparkTestCase):
         ReusedPySparkTestCase.tearDownClass()
         cls.spark.stop()
 
+    def assertPandasEqual(self, expected, result):
+        msg = ("DataFrames are not equal: " +
+               "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) +
+               "\n\nResult:\n%s\n%s" % (result, result.dtypes))
+        self.assertTrue(expected.equals(result), msg=msg)
+
 
 class DataTypeTests(unittest.TestCase):
     # regression test for SPARK-6055
@@ -3371,12 +3377,6 @@ class ArrowTests(ReusedSQLTestCase):
         time.tzset()
         ReusedSQLTestCase.tearDownClass()
 
-    def assertFramesEqual(self, df_with_arrow, df_without):
-        msg = ("DataFrame from Arrow is not equal" +
-               ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) +
-               ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
-        self.assertTrue(df_without.equals(df_with_arrow), msg=msg)
-
     def create_pandas_data_frame(self):
         import pandas as pd
         import numpy as np
@@ -3414,7 +3414,7 @@ class ArrowTests(ReusedSQLTestCase):
     def test_toPandas_arrow_toggle(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
-        self.assertFramesEqual(pdf_arrow, pdf)
+        self.assertPandasEqual(pdf_arrow, pdf)
 
     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
@@ -3425,11 +3425,11 @@ class ArrowTests(ReusedSQLTestCase):
             self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
             try:
                 pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
-                self.assertFramesEqual(pdf_arrow_la, pdf_la)
+                self.assertPandasEqual(pdf_arrow_la, pdf_la)
             finally:
                 self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
             pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
-            self.assertFramesEqual(pdf_arrow_ny, pdf_ny)
+            self.assertPandasEqual(pdf_arrow_ny, pdf_ny)
 
             self.assertFalse(pdf_ny.equals(pdf_la))
 
@@ -3439,7 +3439,7 @@ class ArrowTests(ReusedSQLTestCase):
                 if isinstance(field.dataType, TimestampType):
                     pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
                         pdf_la_corrected[field.name], timezone)
-            self.assertFramesEqual(pdf_ny, pdf_la_corrected)
+            self.assertPandasEqual(pdf_ny, pdf_la_corrected)
         finally:
             self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
@@ -3447,7 +3447,7 @@ class ArrowTests(ReusedSQLTestCase):
         pdf = self.create_pandas_data_frame()
         df = self.spark.createDataFrame(self.data, schema=self.schema)
         pdf_arrow = df.toPandas()
-        self.assertFramesEqual(pdf_arrow, pdf)
+        self.assertPandasEqual(pdf_arrow, pdf)
 
     def test_filtered_frame(self):
         df = self.spark.range(3).toDF("i")
@@ -3505,7 +3505,7 @@ class ArrowTests(ReusedSQLTestCase):
         df = self.spark.createDataFrame(pdf, schema=self.schema)
         self.assertEquals(self.schema, df.schema)
         pdf_arrow = df.toPandas()
-        self.assertFramesEqual(pdf_arrow, pdf)
+        self.assertPandasEqual(pdf_arrow, pdf)
 
     def test_createDataFrame_with_incorrect_schema(self):
         pdf = self.create_pandas_data_frame()
@@ -3717,7 +3717,7 @@ class PandasUDFTests(ReusedSQLTestCase):
 
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
-class VectorizedUDFTests(ReusedSQLTestCase):
+class ScalarPandasUDF(ReusedSQLTestCase):
 
     @classmethod
     def setUpClass(cls):
@@ -4196,13 +4196,7 @@ class VectorizedUDFTests(ReusedSQLTestCase):
 
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
-class GroupbyApplyTests(ReusedSQLTestCase):
-
-    def assertFramesEqual(self, expected, result):
-        msg = ("DataFrames are not equal: " +
-               ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) +
-               ("\n\nResult:\n%s\n%s" % (result, result.dtypes)))
-        self.assertTrue(expected.equals(result), msg=msg)
+class GroupbyApplyPandasUDFTests(ReusedSQLTestCase):
 
     @property
     def data(self):
@@ -4227,7 +4221,7 @@ class GroupbyApplyTests(ReusedSQLTestCase):
 
         result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
         expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
-        self.assertFramesEqual(expected, result)
+        self.assertPandasEqual(expected, result)
 
     def test_register_group_map_udf(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4251,7 +4245,7 @@ class GroupbyApplyTests(ReusedSQLTestCase):
 
         result = df.groupby('id').apply(foo).sort('id').toPandas()
         expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
-        self.assertFramesEqual(expected, result)
+        self.assertPandasEqual(expected, result)
 
     def test_coerce(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4266,7 +4260,7 @@ class GroupbyApplyTests(ReusedSQLTestCase):
         result = df.groupby('id').apply(foo).sort('id').toPandas()
         expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
         expected = expected.assign(v=expected.v.astype('float64'))
-        self.assertFramesEqual(expected, result)
+        self.assertPandasEqual(expected, result)
 
     def test_complex_groupby(self):
         from pyspark.sql.functions import pandas_udf, col, PandasUDFType
@@ -4285,7 +4279,7 @@ class GroupbyApplyTests(ReusedSQLTestCase):
         expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func)
         expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
         expected = expected.assign(norm=expected.norm.astype('float64'))
-        self.assertFramesEqual(expected, result)
+        self.assertPandasEqual(expected, result)
 
     def test_empty_groupby(self):
         from pyspark.sql.functions import pandas_udf, col, PandasUDFType
@@ -4304,7 +4298,7 @@ class GroupbyApplyTests(ReusedSQLTestCase):
         expected = normalize.func(pdf)
         expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
         expected = expected.assign(norm=expected.norm.astype('float64'))
-        self.assertFramesEqual(expected, result)
+        self.assertPandasEqual(expected, result)
 
     def test_datatype_string(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4318,7 +4312,7 @@ class GroupbyApplyTests(ReusedSQLTestCase):
 
         result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
         expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
-        self.assertFramesEqual(expected, result)
+        self.assertPandasEqual(expected, result)
 
     def test_wrong_return_type(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4370,6 +4364,446 @@ class GroupbyApplyTests(ReusedSQLTestCase):
                 df.groupby('id').apply(f).collect()
 
 
+@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+class GroupbyAggPandasUDFTests(ReusedSQLTestCase):
+
+    @property
+    def data(self):
+        from pyspark.sql.functions import array, explode, col, lit
+        return self.spark.range(10).toDF('id') \
+            .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
+            .withColumn("v", explode(col('vs'))) \
+            .drop('vs') \
+            .withColumn('w', lit(1.0))
+
+    @property
+    def python_plus_one(self):
+        from pyspark.sql.functions import udf
+
+        @udf('double')
+        def plus_one(v):
+            assert isinstance(v, (int, float))
+            return v + 1
+        return plus_one
+
+    @property
+    def pandas_scalar_plus_two(self):
+        import pandas as pd
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        @pandas_udf('double', PandasUDFType.SCALAR)
+        def plus_two(v):
+            assert isinstance(v, pd.Series)
+            return v + 2
+        return plus_two
+
+    @property
+    def pandas_agg_mean_udf(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        @pandas_udf('double', PandasUDFType.GROUP_AGG)
+        def avg(v):
+            return v.mean()
+        return avg
+
+    @property
+    def pandas_agg_sum_udf(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        @pandas_udf('double', PandasUDFType.GROUP_AGG)
+        def sum(v):
+            return v.sum()
+        return sum
+
+    @property
+    def pandas_agg_weighted_mean_udf(self):
+        import numpy as np
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        @pandas_udf('double', PandasUDFType.GROUP_AGG)
+        def weighted_mean(v, w):
+            return np.average(v, weights=w)
+        return weighted_mean
+
+    def test_manual(self):
+        df = self.data
+        sum_udf = self.pandas_agg_sum_udf
+        mean_udf = self.pandas_agg_mean_udf
+
+        result1 = df.groupby('id').agg(sum_udf(df.v), mean_udf(df.v)).sort('id')
+        expected1 = self.spark.createDataFrame(
+            [[0, 245.0, 24.5],
+             [1, 255.0, 25.5],
+             [2, 265.0, 26.5],
+             [3, 275.0, 27.5],
+             [4, 285.0, 28.5],
+             [5, 295.0, 29.5],
+             [6, 305.0, 30.5],
+             [7, 315.0, 31.5],
+             [8, 325.0, 32.5],
+             [9, 335.0, 33.5]],
+            ['id', 'sum(v)', 'avg(v)'])
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+    def test_basic(self):
+        from pyspark.sql.functions import col, lit, sum, mean
+
+        df = self.data
+        weighted_mean_udf = self.pandas_agg_weighted_mean_udf
+
+        # Groupby one column and aggregate one UDF with literal
+        result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id')
+        expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id')
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+        # Groupby one expression and aggregate one UDF with literal
+        result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\
+            .sort(df.id + 1)
+        expected2 = df.groupby((col('id') + 1))\
+            .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1)
+        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+        # Groupby one column and aggregate one UDF without literal
+        result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id')
+        expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id')
+        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+
+        # Groupby one expression and aggregate one UDF without literal
+        result4 = df.groupby((col('id') + 1).alias('id'))\
+            .agg(weighted_mean_udf(df.v, df.w))\
+            .sort('id')
+        expected4 = df.groupby((col('id') + 1).alias('id'))\
+            .agg(mean(df.v).alias('weighted_mean(v, w)'))\
+            .sort('id')
+        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+
+    def test_unsupported_types(self):
+        from pyspark.sql.types import ArrayType, DoubleType, MapType
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(NotImplementedError, 'not supported'):
+                @pandas_udf(ArrayType(DoubleType()), PandasUDFType.GROUP_AGG)
+                def mean_and_std_udf(v):
+                    return [v.mean(), v.std()]
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(NotImplementedError, 'not supported'):
+                @pandas_udf('mean double, std double', PandasUDFType.GROUP_AGG)
+                def mean_and_std_udf(v):
+                    return v.mean(), v.std()
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegex(NotImplementedError, 'not supported'):
+                @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUP_AGG)
+                def mean_and_std_udf(v):
+                    return {v.mean(): v.std()}
+
+    def test_alias(self):
+        from pyspark.sql.functions import mean
+
+        df = self.data
+        mean_udf = self.pandas_agg_mean_udf
+
+        result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias'))
+        expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias'))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+    def test_mixed_sql(self):
+        """
+        Test mixing group aggregate pandas UDF with sql expression.
+        """
+        from pyspark.sql.functions import sum, mean
+
+        df = self.data
+        sum_udf = self.pandas_agg_sum_udf
+
+        # Mix group aggregate pandas UDF with sql expression
+        result1 = (df.groupby('id')
+                   .agg(sum_udf(df.v) + 1)
+                   .sort('id'))
+        expected1 = (df.groupby('id')
+                     .agg(sum(df.v) + 1)
+                     .sort('id'))
+
+        # Mix group aggregate pandas UDF with sql expression (order swapped)
+        result2 = (df.groupby('id')
+                     .agg(sum_udf(df.v + 1))
+                     .sort('id'))
+
+        expected2 = (df.groupby('id')
+                       .agg(sum(df.v + 1))
+                       .sort('id'))
+
+        # Wrap group aggregate pandas UDF with two sql expressions
+        result3 = (df.groupby('id')
+                   .agg(sum_udf(df.v + 1) + 2)
+                   .sort('id'))
+        expected3 = (df.groupby('id')
+                     .agg(sum(df.v + 1) + 2)
+                     .sort('id'))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+
+    def test_mixed_udfs(self):
+        """
+        Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF.
+        """
+        from pyspark.sql.functions import sum, mean
+
+        df = self.data
+        plus_one = self.python_plus_one
+        plus_two = self.pandas_scalar_plus_two
+        sum_udf = self.pandas_agg_sum_udf
+
+        # Mix group aggregate pandas UDF and python UDF
+        result1 = (df.groupby('id')
+                   .agg(plus_one(sum_udf(df.v)))
+                   .sort('id'))
+        expected1 = (df.groupby('id')
+                     .agg(plus_one(sum(df.v)))
+                     .sort('id'))
+
+        # Mix group aggregate pandas UDF and python UDF (order swapped)
+        result2 = (df.groupby('id')
+                   .agg(sum_udf(plus_one(df.v)))
+                   .sort('id'))
+        expected2 = (df.groupby('id')
+                     .agg(sum(plus_one(df.v)))
+                     .sort('id'))
+
+        # Mix group aggregate pandas UDF and scalar pandas UDF
+        result3 = (df.groupby('id')
+                   .agg(sum_udf(plus_two(df.v)))
+                   .sort('id'))
+        expected3 = (df.groupby('id')
+                     .agg(sum(plus_two(df.v)))
+                     .sort('id'))
+
+        # Mix group aggregate pandas UDF and scalar pandas UDF (order swapped)
+        result4 = (df.groupby('id')
+                   .agg(plus_two(sum_udf(df.v)))
+                   .sort('id'))
+        expected4 = (df.groupby('id')
+                     .agg(plus_two(sum(df.v)))
+                     .sort('id'))
+
+        # Wrap group aggregate pandas UDF with two python UDFs and use python UDF in groupby
+        result5 = (df.groupby(plus_one(df.id))
+                   .agg(plus_one(sum_udf(plus_one(df.v))))
+                   .sort('plus_one(id)'))
+        expected5 = (df.groupby(plus_one(df.id))
+                     .agg(plus_one(sum(plus_one(df.v))))
+                     .sort('plus_one(id)'))
+
+        # Wrap group aggregate pandas UDF with two scalar pandas UDFs and use scalar pandas UDF in
+        # groupby
+        result6 = (df.groupby(plus_two(df.id))
+                   .agg(plus_two(sum_udf(plus_two(df.v))))
+                   .sort('plus_two(id)'))
+        expected6 = (df.groupby(plus_two(df.id))
+                     .agg(plus_two(sum(plus_two(df.v))))
+                     .sort('plus_two(id)'))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+        self.assertPandasEqual(expected5.toPandas(), result5.toPandas())
+        self.assertPandasEqual(expected6.toPandas(), result6.toPandas())
+
+    def test_multiple_udfs(self):
+        """
+        Test multiple group aggregate pandas UDFs in one agg function.
+        """
+        from pyspark.sql.functions import col, lit, sum, mean
+
+        df = self.data
+        mean_udf = self.pandas_agg_mean_udf
+        sum_udf = self.pandas_agg_sum_udf
+        weighted_mean_udf = self.pandas_agg_weighted_mean_udf
+
+        result1 = (df.groupBy('id')
+                   .agg(mean_udf(df.v),
+                        sum_udf(df.v),
+                        weighted_mean_udf(df.v, df.w))
+                   .sort('id')
+                   .toPandas())
+        expected1 = (df.groupBy('id')
+                     .agg(mean(df.v),
+                          sum(df.v),
+                          mean(df.v).alias('weighted_mean(v, w)'))
+                     .sort('id')
+                     .toPandas())
+
+        self.assertPandasEqual(expected1, result1)
+
+    def test_complex_groupby(self):
+        from pyspark.sql.functions import lit, sum
+
+        df = self.data
+        sum_udf = self.pandas_agg_sum_udf
+        plus_one = self.python_plus_one
+        plus_two = self.pandas_scalar_plus_two
+
+        # groupby one expression
+        result1 = df.groupby(df.v % 2).agg(sum_udf(df.v))
+        expected1 = df.groupby(df.v % 2).agg(sum(df.v))
+
+        # empty groupby
+        result2 = df.groupby().agg(sum_udf(df.v))
+        expected2 = df.groupby().agg(sum(df.v))
+
+        # groupby one column and one sql expression
+        result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v))
+        expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v))
+
+        # groupby one python UDF
+        result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v))
+        expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v))
+
+        # groupby one scalar pandas UDF
+        result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v))
+        expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v))
+
+        # groupby one expression and one python UDF
+        result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v))
+        expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v))
+
+        # groupby one expression and one scalar pandas UDF
+        result7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)')
+        expected7 = df.groupby(df.v % 2, plus_two(df.id)).agg(sum(df.v)).sort('sum(v)')
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+        self.assertPandasEqual(expected5.toPandas(), result5.toPandas())
+        self.assertPandasEqual(expected6.toPandas(), result6.toPandas())
+        self.assertPandasEqual(expected7.toPandas(), result7.toPandas())
+
+    def test_complex_expressions(self):
+        from pyspark.sql.functions import col, sum
+
+        df = self.data
+        plus_one = self.python_plus_one
+        plus_two = self.pandas_scalar_plus_two
+        sum_udf = self.pandas_agg_sum_udf
+
+        # Test complex expressions with sql expression, python UDF and
+        # group aggregate pandas UDF
+        result1 = (df.withColumn('v1', plus_one(df.v))
+                   .withColumn('v2', df.v + 2)
+                   .groupby(df.id, df.v % 2)
+                   .agg(sum_udf(col('v')),
+                        sum_udf(col('v1') + 3),
+                        sum_udf(col('v2')) + 5,
+                        plus_one(sum_udf(col('v1'))),
+                        sum_udf(plus_one(col('v2'))))
+                   .sort('id')
+                   .toPandas())
+
+        expected1 = (df.withColumn('v1', df.v + 1)
+                     .withColumn('v2', df.v + 2)
+                     .groupby(df.id, df.v % 2)
+                     .agg(sum(col('v')),
+                          sum(col('v1') + 3),
+                          sum(col('v2')) + 5,
+                          plus_one(sum(col('v1'))),
+                          sum(plus_one(col('v2'))))
+                     .sort('id')
+                     .toPandas())
+
+        # Test complex expressions with sql expression, scalar pandas UDF and
+        # group aggregate pandas UDF
+        result2 = (df.withColumn('v1', plus_one(df.v))
+                   .withColumn('v2', df.v + 2)
+                   .groupby(df.id, df.v % 2)
+                   .agg(sum_udf(col('v')),
+                        sum_udf(col('v1') + 3),
+                        sum_udf(col('v2')) + 5,
+                        plus_two(sum_udf(col('v1'))),
+                        sum_udf(plus_two(col('v2'))))
+                   .sort('id')
+                   .toPandas())
+
+        expected2 = (df.withColumn('v1', df.v + 1)
+                     .withColumn('v2', df.v + 2)
+                     .groupby(df.id, df.v % 2)
+                     .agg(sum(col('v')),
+                          sum(col('v1') + 3),
+                          sum(col('v2')) + 5,
+                          plus_two(sum(col('v1'))),
+                          sum(plus_two(col('v2'))))
+                     .sort('id')
+                     .toPandas())
+
+        # Test sequential groupby aggregate
+        result3 = (df.groupby('id')
+                   .agg(sum_udf(df.v).alias('v'))
+                   .groupby('id')
+                   .agg(sum_udf(col('v')))
+                   .sort('id')
+                   .toPandas())
+
+        expected3 = (df.groupby('id')
+                     .agg(sum(df.v).alias('v'))
+                     .groupby('id')
+                     .agg(sum(col('v')))
+                     .sort('id')
+                     .toPandas())
+
+        self.assertPandasEqual(expected1, result1)
+        self.assertPandasEqual(expected2, result2)
+        self.assertPandasEqual(expected3, result3)
+
+    def test_retain_group_columns(self):
+        from pyspark.sql.functions import sum, lit, col
+        orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None)
+        self.spark.conf.set("spark.sql.retainGroupColumns", False)
+        try:
+            df = self.data
+            sum_udf = self.pandas_agg_sum_udf
+
+            result1 = df.groupby(df.id).agg(sum_udf(df.v))
+            expected1 = df.groupby(df.id).agg(sum(df.v))
+            self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+        finally:
+            if orig_value is None:
+                self.spark.conf.unset("spark.sql.retainGroupColumns")
+            else:
+                self.spark.conf.set("spark.sql.retainGroupColumns", orig_value)
+
+    def test_invalid_args(self):
+        from pyspark.sql.functions import mean
+
+        df = self.data
+        plus_one = self.python_plus_one
+        mean_udf = self.pandas_agg_mean_udf
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    AnalysisException,
+                    'nor.*aggregate function'):
+                df.groupby(df.id).agg(plus_one(df.v)).collect()
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    AnalysisException,
+                    'aggregate function.*argument.*aggregate function'):
+                df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect()
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    AnalysisException,
+                    'mixture.*aggregate function.*group aggregate pandas UDF'):
+                df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
+
 if __name__ == "__main__":
     from pyspark.sql.tests import *
     if xmlrunner:

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/python/pyspark/sql/udf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 134badb..de96846 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -22,7 +22,8 @@ import functools
 from pyspark import SparkContext, since
 from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
 from pyspark.sql.column import Column, _to_java_column, _to_seq
-from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string
+from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \
+    _parse_datatype_string
 
 __all__ = ["UDFRegistration"]
 
@@ -36,8 +37,10 @@ def _wrap_function(sc, func, returnType):
 
 def _create_udf(f, returnType, evalType):
 
-    if evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF or \
-            evalType == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
+    if evalType in (PythonEvalType.SQL_PANDAS_SCALAR_UDF,
+                    PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF,
+                    PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF):
+
         import inspect
         from pyspark.sql.utils import require_minimum_pyarrow_version
 
@@ -113,6 +116,10 @@ class UserDefinedFunction(object):
                 and not isinstance(self._returnType_placeholder, StructType):
             raise ValueError("Invalid returnType: returnType must be a StructType for "
                              "pandas_udf with function type GROUP_MAP")
+        elif self.evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF \
+                and isinstance(self._returnType_placeholder, (StructType, ArrayType, MapType)):
+            raise NotImplementedError(
+                "ArrayType, StructType and MapType are not supported with PandasUDFType.GROUP_AGG")
 
         return self._returnType_placeholder
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index e6737ae..173d8fb 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -110,6 +110,17 @@ def wrap_pandas_group_map_udf(f, return_type):
     return wrapped
 
 
+def wrap_pandas_group_agg_udf(f, return_type):
+    arrow_return_type = to_arrow_type(return_type)
+
+    def wrapped(*series):
+        import pandas as pd
+        result = f(*series)
+        return pd.Series(result)
+
+    return lambda *a: (wrapped(*a), arrow_return_type)
+
+
 def read_single_udf(pickleSer, infile, eval_type):
     num_arg = read_int(infile)
     arg_offsets = [read_int(infile) for i in range(num_arg)]
@@ -126,8 +137,12 @@ def read_single_udf(pickleSer, infile, eval_type):
         return arg_offsets, wrap_pandas_scalar_udf(row_func, return_type)
     elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
         return arg_offsets, wrap_pandas_group_map_udf(row_func, return_type)
-    else:
+    elif eval_type == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF:
+        return arg_offsets, wrap_pandas_group_agg_udf(row_func, return_type)
+    elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
         return arg_offsets, wrap_udf(row_func, return_type)
+    else:
+        raise ValueError("Unknown eval type: {}".format(eval_type))
 
 
 def read_udfs(pickleSer, infile, eval_type):
@@ -148,8 +163,9 @@ def read_udfs(pickleSer, infile, eval_type):
 
     func = lambda _, it: map(mapper, it)
 
-    if eval_type == PythonEvalType.SQL_PANDAS_SCALAR_UDF \
-       or eval_type == PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF:
+    if eval_type in (PythonEvalType.SQL_PANDAS_SCALAR_UDF,
+                     PythonEvalType.SQL_PANDAS_GROUP_MAP_UDF,
+                     PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF):
         timezone = utf8_deserializer.loads(infile)
         ser = ArrowStreamPandasSerializer(timezone)
     else:

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index bbcec56..ef91d79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -153,11 +153,19 @@ trait CheckAnalysis extends PredicateHelper {
                 s"of type ${condition.dataType.simpleString} is not a boolean.")
 
           case Aggregate(groupingExprs, aggregateExprs, child) =>
+            def isAggregateExpression(expr: Expression) = {
+              expr.isInstanceOf[AggregateExpression] || PythonUDF.isGroupAggPandasUDF(expr)
+            }
+
             def checkValidAggregateExpression(expr: Expression): Unit = expr match {
-              case aggExpr: AggregateExpression =>
-                aggExpr.aggregateFunction.children.foreach { child =>
+              case expr: Expression if isAggregateExpression(expr) =>
+                val aggFunction = expr match {
+                  case agg: AggregateExpression => agg.aggregateFunction
+                  case udf: PythonUDF => udf
+                }
+                aggFunction.children.foreach { child =>
                   child.foreach {
-                    case agg: AggregateExpression =>
+                    case expr: Expression if isAggregateExpression(expr) =>
                       failAnalysis(
                         s"It is not allowed to use an aggregate function in the argument of " +
                           s"another aggregate function. Please use the inner aggregate function " +

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
new file mode 100644
index 0000000..4ba8ff6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.api.python.{PythonEvalType, PythonFunction}
+import org.apache.spark.sql.catalyst.util.toPrettySQL
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Helper functions for [[PythonUDF]]
+ */
+object PythonUDF {
+  private[this] val SCALAR_TYPES = Set(
+    PythonEvalType.SQL_BATCHED_UDF,
+    PythonEvalType.SQL_PANDAS_SCALAR_UDF
+  )
+
+  def isScalarPythonUDF(e: Expression): Boolean = {
+    e.isInstanceOf[PythonUDF] && SCALAR_TYPES.contains(e.asInstanceOf[PythonUDF].evalType)
+  }
+
+  def isGroupAggPandasUDF(e: Expression): Boolean = {
+    e.isInstanceOf[PythonUDF] &&
+      e.asInstanceOf[PythonUDF].evalType == PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF
+  }
+}
+
+/**
+ * A serialized version of a Python lambda function.
+ */
+case class PythonUDF(
+    name: String,
+    func: PythonFunction,
+    dataType: DataType,
+    children: Seq[Expression],
+    evalType: Int,
+    udfDeterministic: Boolean,
+    resultId: ExprId = NamedExpression.newExprId)
+  extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {
+
+  override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
+
+  override def toString: String = s"$name(${children.mkString(", ")})"
+
+  lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)(
+    exprId = resultId)
+
+  override def nullable: Boolean = true
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index cc391aa..1322410 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.planning
 
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -199,7 +200,7 @@ object ExtractFiltersAndInnerJoins extends PredicateHelper {
 object PhysicalAggregation {
   // groupingExpressions, aggregateExpressions, resultExpressions, child
   type ReturnType =
-    (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan)
+    (Seq[NamedExpression], Seq[Expression], Seq[NamedExpression], LogicalPlan)
 
   def unapply(a: Any): Option[ReturnType] = a match {
     case logical.Aggregate(groupingExpressions, resultExpressions, child) =>
@@ -213,7 +214,10 @@ object PhysicalAggregation {
         expr.collect {
           // addExpr() always returns false for non-deterministic expressions and do not add them.
           case agg: AggregateExpression
-            if (!equivalentAggregateExpressions.addExpr(agg)) => agg
+            if !equivalentAggregateExpressions.addExpr(agg) => agg
+          case udf: PythonUDF
+            if PythonUDF.isGroupAggPandasUDF(udf) &&
+              !equivalentAggregateExpressions.addExpr(udf) => udf
         }
       }
 
@@ -241,6 +245,10 @@ object PhysicalAggregation {
             // so replace each aggregate expression by its corresponding attribute in the set:
             equivalentAggregateExpressions.getEquivalentExprs(ae).headOption
               .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
+            // Similar to AggregateExpression
+          case ue: PythonUDF if PythonUDF.isGroupAggPandasUDF(ue) =>
+            equivalentAggregateExpressions.getEquivalentExprs(ue).headOption
+              .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute
           case expression =>
             // Since we're using `namedGroupingAttributes` to extract the grouping key
             // columns, we need to replace grouping key expressions with their corresponding

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index a009c00..d320c1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
-import org.apache.spark.sql.execution.python.PythonUDF
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{NumericType, StructType}
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 9102948..ce512bc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.{execution, AnalysisException, Strategy}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -288,9 +289,14 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case PhysicalAggregation(
         namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) =>
 
+        if (aggregateExpressions.exists(PythonUDF.isGroupAggPandasUDF)) {
+          throw new AnalysisException(
+            "Streaming aggregation doesn't support group aggregate pandas UDF")
+        }
+
         aggregate.AggUtils.planStreamingAggregation(
           namedGroupingExpressions,
-          aggregateExpressions,
+          aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
           rewrittenResultExpressions,
           planLater(child))
 
@@ -333,8 +339,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    */
   object Aggregation extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case PhysicalAggregation(
-          groupingExpressions, aggregateExpressions, resultExpressions, child) =>
+      case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child)
+        if aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression]) =>
+        val aggregateExpressions = aggExpressions.map(expr =>
+          expr.asInstanceOf[AggregateExpression])
 
         val (functionsWithDistinct, functionsWithoutDistinct) =
           aggregateExpressions.partition(_.isDistinct)
@@ -363,6 +371,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
         aggregateOperator
 
+      case PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child)
+        if aggExpressions.forall(expr => expr.isInstanceOf[PythonUDF]) =>
+        val udfExpressions = aggExpressions.map(expr => expr.asInstanceOf[PythonUDF])
+
+        Seq(execution.python.AggregateInPandasExec(
+          groupingExpressions,
+          udfExpressions,
+          resultExpressions,
+          planLater(child)))
+
+      case PhysicalAggregation(_, _, _, _) =>
+        // If neither of the two cases above matched, it's an error
+        throw new AnalysisException(
+          "Cannot use a mixture of aggregate function and group aggregate pandas UDF")
+
       case _ => Nil
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
new file mode 100644
index 0000000..18e5f86
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.File
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+/**
+ * Physical node for aggregation with group aggregate Pandas UDF.
+ *
+ * This plan works by sending the necessary (projected) input grouped data as Arrow record batches
+ * to the python worker; the python worker invokes the UDF and sends the results back to the
+ * executor; finally, the executor evaluates any post-aggregation expressions and joins the
+ * result with the grouping key.
+ */
+case class AggregateInPandasExec(
+    groupingExpressions: Seq[NamedExpression],
+    udfExpressions: Seq[PythonUDF],
+    resultExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryExecNode {
+
+  override val output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def producedAttributes: AttributeSet = AttributeSet(output)
+
+  override def requiredChildDistribution: Seq[Distribution] = {
+    if (groupingExpressions.isEmpty) {
+      AllTuples :: Nil
+    } else {
+      ClusteredDistribution(groupingExpressions) :: Nil
+    }
+  }
+
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+    udf.children match {
+      case Seq(u: PythonUDF) =>
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+      case children =>
+        // There should not be any other UDFs, or the children can't be evaluated directly.
+        assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+    }
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    Seq(groupingExpressions.map(SortOrder(_, Ascending)))
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    val inputRDD = child.execute()
+
+    val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
+    val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
+    val sessionLocalTimeZone = conf.sessionLocalTimeZone
+    val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
+
+    val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
+
+    // Filter child output attributes down to only those that are UDF inputs.
+    // Also eliminate duplicate UDF inputs.
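+    // For example, with UDFs f(v) and g(v, w), allInputs becomes [v, w], dataTypes becomes
+    // [v.dataType, w.dataType], and argOffsets becomes [[0], [0, 1]].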
+    val allInputs = new ArrayBuffer[Expression]
+    val dataTypes = new ArrayBuffer[DataType]
+    val argOffsets = inputs.map { input =>
+      input.map { e =>
+        if (allInputs.exists(_.semanticEquals(e))) {
+          allInputs.indexWhere(_.semanticEquals(e))
+        } else {
+          allInputs += e
+          dataTypes += e.dataType
+          allInputs.length - 1
+        }
+      }.toArray
+    }.toArray
+
+    // Schema of input rows to the python runner
+    val aggInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
+      StructField(s"_$i", dt)
+    })
+
+    inputRDD.mapPartitionsInternal { iter =>
+      val prunedProj = UnsafeProjection.create(allInputs, child.output)
+
+      val grouped = if (groupingExpressions.isEmpty) {
+        // Use an empty unsafe row as a placeholder for the grouping key
+        Iterator((new UnsafeRow(), iter))
+      } else {
+        GroupedIterator(iter, groupingExpressions, child.output)
+      }.map { case (key, rows) =>
+        (key, rows.map(prunedProj))
+      }
+
+      val context = TaskContext.get()
+
+      // The queue used to buffer the grouping keys so we can drain it later to
+      // join each key with the corresponding aggregation output from Python.
+      val queue = HybridRowQueue(context.taskMemoryManager(),
+        new File(Utils.getLocalDir(SparkEnv.get.conf)), groupingExpressions.length)
+      context.addTaskCompletionListener { _ =>
+        queue.close()
+      }
+
+      // Add each grouping key to the queue so it can be joined later with the aggregation result.
+      val projectedRowIter = grouped.map { case (groupingKey, rows) =>
+        queue.add(groupingKey.asInstanceOf[UnsafeRow])
+        rows
+      }
+
+      val columnarBatchIter = new ArrowPythonRunner(
+        pyFuncs, bufferSize, reuseWorker,
+        PythonEvalType.SQL_PANDAS_GROUP_AGG_UDF, argOffsets, aggInputSchema,
+        sessionLocalTimeZone, pandasRespectSessionTimeZone)
+        .compute(projectedRowIter, context.partitionId(), context)
+
+      val joinedAttributes =
+        groupingExpressions.map(_.toAttribute) ++ udfExpressions.map(_.resultAttribute)
+      val joined = new JoinedRow
+      val resultProj = UnsafeProjection.create(resultExpressions, joinedAttributes)
+
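+      // Each Arrow batch returned from Python holds the single aggregation result row for one
+      // group; join it with the buffered grouping key and apply the result projection.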
+      columnarBatchIter.map(_.rowIterator.next()).map { aggOutputRow =>
+        val leftRow = queue.remove()
+        val joinedRow = joined(leftRow, aggOutputRow)
+        resultProj(joinedRow)
+      }
+    }
+  }
+}
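
For context beyond the diff itself, here is a minimal PySpark sketch (not part of this commit) of
the kind of query that should be planned as AggregateInPandasExec. It assumes an existing
SparkSession named spark and the PandasUDFType.GROUP_AGG marker introduced by this change; the
data and mean_udf are hypothetical. Because group aggregate Pandas UDFs do not support partial
aggregation, the required child distribution above forces each group onto a single partition (a
full shuffle) before the UDF runs.

    from pyspark.sql.functions import pandas_udf, PandasUDFType

    df = spark.createDataFrame(
        [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
        ("id", "v"))

    @pandas_udf("double", PandasUDFType.GROUP_AGG)
    def mean_udf(v):
        # v is a pandas.Series with all values of the group; return a single scalar
        return v.mean()

    # Each group is sent to the Python worker as Arrow record batches (full shuffle required)
    df.groupby("id").agg(mean_udf(df["v"])).show()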

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 2f53fe7..1862e3f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -39,12 +39,13 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
    */
   private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
     e.isInstanceOf[AggregateExpression] ||
+      PythonUDF.isGroupAggPandasUDF(e) ||
       agg.groupingExpressions.exists(_.semanticEquals(e))
   }
 
   private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
     expr.find {
-      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
+      e => PythonUDF.isScalarPythonUDF(e) && e.find(belongAggregate(_, agg)).isDefined
     }.isDefined
   }
 
@@ -93,7 +94,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
 object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
 
   private def hasPythonUDF(e: Expression): Boolean = {
-    e.find(_.isInstanceOf[PythonUDF]).isDefined
+    e.find(PythonUDF.isScalarPythonUDF).isDefined
   }
 
   private def canEvaluateInPython(e: PythonUDF): Boolean = {
@@ -106,12 +107,12 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
   }
 
   private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match {
-    case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf)
+    case udf: PythonUDF if PythonUDF.isScalarPythonUDF(udf) && canEvaluateInPython(udf) => Seq(udf)
     case e => e.children.flatMap(collectEvaluatableUDF)
   }
 
   def apply(plan: SparkPlan): SparkPlan = plan transformUp {
-    // FlatMapGroupsInPandas can be evaluated directly in python worker
+    // AggregateInPandasExec and FlatMapGroupsInPandasExec can be evaluated directly in the Python worker
     // Therefore we don't need to extract the UDFs
     case plan: FlatMapGroupsInPandasExec => plan
     case plan: SparkPlan => extract(plan)
@@ -149,10 +150,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
           udf.references.subsetOf(child.outputSet)
         }
         if (validUdfs.nonEmpty) {
-          require(validUdfs.forall(udf =>
-            udf.evalType == PythonEvalType.SQL_BATCHED_UDF ||
-            udf.evalType == PythonEvalType.SQL_PANDAS_SCALAR_UDF
-          ), "Can only extract scalar vectorized udf or sql batch udf")
+          require(
+            validUdfs.forall(PythonUDF.isScalarPythonUDF),
+            "Can only extract scalar vectorized udf or sql batch udf")
 
           val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
             AttributeReference(s"pythonUDF$i", u.dataType)()
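
As a hedged illustration of the distinction these rules now draw (again not part of the commit,
reusing the hypothetical df, imports and session from the sketch above): a scalar Pandas UDF
layered on top of a group aggregate Pandas UDF should be pulled out of the Aggregate by
ExtractPythonUDFFromAggregate and evaluated afterwards, while the group aggregate UDF itself is
executed inside the aggregation.

    @pandas_udf("double", PandasUDFType.SCALAR)
    def plus_one(v):
        # scalar vectorized UDF: pandas.Series -> pandas.Series
        return v + 1

    @pandas_udf("double", PandasUDFType.GROUP_AGG)
    def sum_udf(v):
        return v.sum()

    # sum_udf is evaluated as part of the aggregation; plus_one is applied to its result
    df.groupby("id").agg(plus_one(sum_udf(df["v"]))).show()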

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
deleted file mode 100644
index d3f743d..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.python
-
-import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable, UserDefinedExpression}
-import org.apache.spark.sql.types.DataType
-
-/**
- * A serialized version of a Python lambda function.
- */
-case class PythonUDF(
-    name: String,
-    func: PythonFunction,
-    dataType: DataType,
-    children: Seq[Expression],
-    evalType: Int,
-    udfDeterministic: Boolean)
-  extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {
-
-  override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
-
-  override def toString: String = s"$name(${children.mkString(", ")})"
-
-  override def nullable: Boolean = true
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/b2ce17b4/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 50dca32..f4c2d02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python
 
 import org.apache.spark.api.python.PythonFunction
 import org.apache.spark.sql.Column
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, PythonUDF}
 import org.apache.spark.sql.types.DataType
 
 /**

