You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2018/03/08 11:29:11 UTC

spark git commit: [SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF

Repository: spark
Updated Branches:
  refs/heads/master d6632d185 -> 2cb23a8f5


[SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF

## What changes were proposed in this pull request?

This PR proposes to support an alternative function form for grouped map pandas UDFs.

The current form:
```
def foo(pdf):
    return ...
```
It takes a single argument, which is a pandas DataFrame.

With this PR, an alternative form is supported:
```
def foo(key, pdf):
    return ...
```
The alternative form takes two arguments: a tuple that represents the grouping key, and a pandas DataFrame that represents the data.

## How was this patch tested?

GroupbyApplyTests

Author: Li Jin <ic...@gmail.com>

Closes #20295 from icexelloss/SPARK-23011-groupby-apply-key.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2cb23a8f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2cb23a8f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2cb23a8f

Branch: refs/heads/master
Commit: 2cb23a8f51a151970c121015fcbad9beeafa8295
Parents: d6632d1
Author: Li Jin <ic...@gmail.com>
Authored: Thu Mar 8 20:29:07 2018 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Thu Mar 8 20:29:07 2018 +0900

----------------------------------------------------------------------
 python/pyspark/serializers.py                   |  18 +--
 python/pyspark/sql/functions.py                 |  25 ++++
 python/pyspark/sql/tests.py                     | 121 +++++++++++++++++--
 python/pyspark/sql/types.py                     |  45 +++++--
 python/pyspark/sql/udf.py                       |  19 ++-
 python/pyspark/util.py                          |  16 +++
 python/pyspark/worker.py                        |  49 ++++++--
 .../python/FlatMapGroupsInPandasExec.scala      |  56 ++++++++-
 8 files changed, 294 insertions(+), 55 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 917e258..ebf5493 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -250,6 +250,15 @@ class ArrowStreamPandasSerializer(Serializer):
         super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
 
+    def arrow_to_pandas(self, arrow_column):
+        from pyspark.sql.types import from_arrow_type, \
+            _check_series_convert_date, _check_series_localize_timestamps
+
+        s = arrow_column.to_pandas()
+        s = _check_series_convert_date(s, from_arrow_type(arrow_column.type))
+        s = _check_series_localize_timestamps(s, self._timezone)
+        return s
+
     def dump_stream(self, iterator, stream):
         """
         Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
@@ -272,16 +281,11 @@ class ArrowStreamPandasSerializer(Serializer):
         """
         Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
         """
-        from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \
-            _check_dataframe_localize_timestamps
         import pyarrow as pa
         reader = pa.open_stream(stream)
-        schema = from_arrow_schema(reader.schema)
+
         for batch in reader:
-            pdf = batch.to_pandas()
-            pdf = _check_dataframe_convert_date(pdf, schema)
-            pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
-            yield [c for _, c in pdf.iteritems()]
+            yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
 
     def __repr__(self):
         return "ArrowStreamPandasSerializer"

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b9c0c57..dc1341a 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2267,6 +2267,31 @@ def pandas_udf(f=None, returnType=None, functionType=None):
        |  2| 1.1094003924504583|
        +---+-------------------+
 
+       Alternatively, the user can define a function that takes two arguments.
+       In this case, the grouping key will be passed as the first argument and the data will
+       be passed as the second argument. The grouping key will be passed as a tuple of numpy
+       data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
+       as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
+       This is useful when the user does not want to hardcode grouping key in the function.
+
+       >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+       >>> import pandas as pd  # doctest: +SKIP
+       >>> df = spark.createDataFrame(
+       ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+       ...     ("id", "v"))  # doctest: +SKIP
+       >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)  # doctest: +SKIP
+       ... def mean_udf(key, pdf):
+       ...     # key is a tuple of one numpy.int64, which is the value
+       ...     # of 'id' for the current group
+       ...     return pd.DataFrame([key + (pdf.v.mean(),)])
+       >>> df.groupby('id').apply(mean_udf).show()  # doctest: +SKIP
+       +---+---+
+       | id|  v|
+       +---+---+
+       |  1|1.5|
+       |  2|6.0|
+       +---+---+
+
        .. seealso:: :meth:`pyspark.sql.GroupedData.apply`
 
     3. GROUPED_AGG

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a9fe0b4..480815d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3903,7 +3903,7 @@ class PandasUDFTests(ReusedSQLTestCase):
                     return df
             with self.assertRaisesRegexp(ValueError, 'Invalid function'):
                 @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
-                def foo(k, v):
+                def foo(k, v, w):
                     return k
 
 
@@ -4476,20 +4476,45 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
         df = self.data.withColumn("arr", array(col("id")))
 
-        foo_udf = pandas_udf(
+        # Different forms of group map pandas UDF, results of these are the same
+
+        output_schema = StructType(
+            [StructField('id', LongType()),
+             StructField('v', IntegerType()),
+             StructField('arr', ArrayType(LongType())),
+             StructField('v1', DoubleType()),
+             StructField('v2', LongType())])
+
+        udf1 = pandas_udf(
             lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
-            StructType(
-                [StructField('id', LongType()),
-                 StructField('v', IntegerType()),
-                 StructField('arr', ArrayType(LongType())),
-                 StructField('v1', DoubleType()),
-                 StructField('v2', LongType())]),
+            output_schema,
             PandasUDFType.GROUPED_MAP
         )
 
-        result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
-        expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
-        self.assertPandasEqual(expected, result)
+        udf2 = pandas_udf(
+            lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            output_schema,
+            PandasUDFType.GROUPED_MAP
+        )
+
+        udf3 = pandas_udf(
+            lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            output_schema,
+            PandasUDFType.GROUPED_MAP
+        )
+
+        result1 = df.groupby('id').apply(udf1).sort('id').toPandas()
+        expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True)
+
+        result2 = df.groupby('id').apply(udf2).sort('id').toPandas()
+        expected2 = expected1
+
+        result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
+        expected3 = expected1
+
+        self.assertPandasEqual(expected1, result1)
+        self.assertPandasEqual(expected2, result2)
+        self.assertPandasEqual(expected3, result3)
 
     def test_register_grouped_map_udf(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4648,6 +4673,80 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
         result = df.groupby('time').apply(foo_udf).sort('time')
         self.assertPandasEqual(df.toPandas(), result.toPandas())
 
+    def test_udf_with_key(self):
+        from pyspark.sql.functions import pandas_udf, col, PandasUDFType
+        df = self.data
+        pdf = df.toPandas()
+
+        def foo1(key, pdf):
+            import numpy as np
+            assert type(key) == tuple
+            assert type(key[0]) == np.int64
+
+            return pdf.assign(v1=key[0],
+                              v2=pdf.v * key[0],
+                              v3=pdf.v * pdf.id,
+                              v4=pdf.v * pdf.id.mean())
+
+        def foo2(key, pdf):
+            import numpy as np
+            assert type(key) == tuple
+            assert type(key[0]) == np.int64
+            assert type(key[1]) == np.int32
+
+            return pdf.assign(v1=key[0],
+                              v2=key[1],
+                              v3=pdf.v * key[0],
+                              v4=pdf.v + key[1])
+
+        def foo3(key, pdf):
+            assert type(key) == tuple
+            assert len(key) == 0
+            return pdf.assign(v1=pdf.v * pdf.id)
+
+        # v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
+        # v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
+        udf1 = pandas_udf(
+            foo1,
+            'id long, v int, v1 long, v2 int, v3 long, v4 double',
+            PandasUDFType.GROUPED_MAP)
+
+        udf2 = pandas_udf(
+            foo2,
+            'id long, v int, v1 long, v2 int, v3 int, v4 int',
+            PandasUDFType.GROUPED_MAP)
+
+        udf3 = pandas_udf(
+            foo3,
+            'id long, v int, v1 long',
+            PandasUDFType.GROUPED_MAP)
+
+        # Test groupby column
+        result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
+        expected1 = pdf.groupby('id')\
+            .apply(lambda x: udf1.func((x.id.iloc[0],), x))\
+            .sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected1, result1)
+
+        # Test groupby expression
+        result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
+        expected2 = pdf.groupby(pdf.id % 2)\
+            .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
+            .sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected2, result2)
+
+        # Test complex groupby
+        result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
+        expected3 = pdf.groupby([pdf.id, pdf.v % 2])\
+            .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\
+            .sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected3, result3)
+
+        # Test empty groupby
+        result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
+        expected4 = udf3.func((), pdf)
+        self.assertPandasEqual(expected4, result4)
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index cd85740..1632862 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1695,6 +1695,19 @@ def from_arrow_schema(arrow_schema):
          for field in arrow_schema])
 
 
+def _check_series_convert_date(series, data_type):
+    """
+    Cast the series to datetime.date if it's a date type, otherwise returns the original series.
+
+    :param series: pandas.Series
+    :param data_type: a Spark data type for the series
+    """
+    if type(data_type) == DateType:
+        return series.dt.date
+    else:
+        return series
+
+
 def _check_dataframe_convert_date(pdf, schema):
     """ Correct date type value to use datetime.date.
 
@@ -1705,8 +1718,7 @@ def _check_dataframe_convert_date(pdf, schema):
     :param schema: a Spark schema of the pandas.DataFrame
     """
     for field in schema:
-        if type(field.dataType) == DateType:
-            pdf[field.name] = pdf[field.name].dt.date
+        pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType)
     return pdf
 
 
@@ -1725,6 +1737,29 @@ def _get_local_timezone():
     return os.environ.get('TZ', 'dateutil/:')
 
 
+def _check_series_localize_timestamps(s, timezone):
+    """
+    Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
+
+    If the input series is not a timestamp series, then the same series is returned. If the input
+    series is a timestamp series, then a converted series is returned.
+
+    :param s: pandas.Series
+    :param timezone: the timezone to convert. if None then use local timezone
+    :return pandas.Series that have been converted to tz-naive
+    """
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+
+    from pandas.api.types import is_datetime64tz_dtype
+    tz = timezone or _get_local_timezone()
+    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+    if is_datetime64tz_dtype(s.dtype):
+        return s.dt.tz_convert(tz).dt.tz_localize(None)
+    else:
+        return s
+
+
 def _check_dataframe_localize_timestamps(pdf, timezone):
     """
     Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
@@ -1736,12 +1771,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone):
     from pyspark.sql.utils import require_minimum_pandas_version
     require_minimum_pandas_version()
 
-    from pandas.api.types import is_datetime64tz_dtype
-    tz = timezone or _get_local_timezone()
     for column, series in pdf.iteritems():
-        # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
-        if is_datetime64tz_dtype(series.dtype):
-            pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None)
+        pdf[column] = _check_series_localize_timestamps(series, timezone)
     return pdf
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/sql/udf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index b9b4908..ce804c1 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -17,6 +17,8 @@
 """
 User-defined function related classes and functions
 """
+import sys
+import inspect
 import functools
 
 from pyspark import SparkContext, since
@@ -24,6 +26,7 @@ from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \
     _parse_datatype_string, to_arrow_type, to_arrow_schema
+from pyspark.util import _get_argspec
 
 __all__ = ["UDFRegistration"]
 
@@ -41,18 +44,10 @@ def _create_udf(f, returnType, evalType):
                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
 
-        import inspect
-        import sys
         from pyspark.sql.utils import require_minimum_pyarrow_version
-
         require_minimum_pyarrow_version()
 
-        if sys.version_info[0] < 3:
-            # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
-            # See SPARK-23569.
-            argspec = inspect.getargspec(f)
-        else:
-            argspec = inspect.getfullargspec(f)
+        argspec = _get_argspec(f)
 
         if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \
                 argspec.varargs is None:
@@ -61,11 +56,11 @@ def _create_udf(f, returnType, evalType):
                 "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
             )
 
-        if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1:
+        if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
+                and len(argspec.args) not in (1, 2):
             raise ValueError(
                 "Invalid function: pandas_udfs with function type GROUPED_MAP "
-                "must take a single arg that is a pandas DataFrame."
-            )
+                "must take either one argument (data) or two arguments (key, data).")
 
     # Set the name of the UserDefinedFunction object to be the name of function f
     udf_obj = UserDefinedFunction(

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index ad4a0bc..6837b18 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -15,6 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+import sys
+import inspect
 from py4j.protocol import Py4JJavaError
 
 __all__ = []
@@ -45,6 +48,19 @@ def _exception_message(excp):
     return str(excp)
 
 
+def _get_argspec(f):
+    """
+    Get argspec of a function. Supports both Python 2 and Python 3.
+    """
+    # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
+    # See SPARK-23569.
+    if sys.version_info[0] < 3:
+        argspec = inspect.getargspec(f)
+    else:
+        argspec = inspect.getfullargspec(f)
+    return argspec
+
+
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 89a3a92..202cac3 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -34,6 +34,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import to_arrow_type
+from pyspark.util import _get_argspec
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -91,10 +92,16 @@ def wrap_scalar_pandas_udf(f, return_type):
 
 
 def wrap_grouped_map_pandas_udf(f, return_type):
-    def wrapped(*series):
+    def wrapped(key_series, value_series):
         import pandas as pd
+        argspec = _get_argspec(f)
+
+        if len(argspec.args) == 1:
+            result = f(pd.concat(value_series, axis=1))
+        elif len(argspec.args) == 2:
+            key = tuple(s[0] for s in key_series)
+            result = f(key, pd.concat(value_series, axis=1))
 
-        result = f(pd.concat(series, axis=1))
         if not isinstance(result, pd.DataFrame):
             raise TypeError("Return type of the user-defined function should be "
                             "pandas.DataFrame, but is {}".format(type(result)))
@@ -149,18 +156,36 @@ def read_udfs(pickleSer, infile, eval_type):
     num_udfs = read_int(infile)
     udfs = {}
     call_udf = []
-    for i in range(num_udfs):
+    mapper_str = ""
+    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+        # Create function like this:
+        #   lambda a: f([a[0]], [a[0], a[1]])
+
+        # We assume there is only one UDF here because grouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+
+        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
+        # distinguish between grouping attributes and data attributes
         arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
-        udfs['f%d' % i] = udf
-        args = ["a[%d]" % o for o in arg_offsets]
-        call_udf.append("f%d(%s)" % (i, ", ".join(args)))
-    # Create function like this:
-    #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
-    # In the special case of a single UDF this will return a single result rather
-    # than a tuple of results; this is the format that the JVM side expects.
-    mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
-    mapper = eval(mapper_str, udfs)
+        udfs['f'] = udf
+        split_offset = arg_offsets[0] + 1
+        arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
+        arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
+        mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
+    else:
+        # Create function like this:
+        #   lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
+        # In the special case of a single UDF this will return a single result rather
+        # than a tuple of results; this is the format that the JVM side expects.
+        for i in range(num_udfs):
+            arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
+            udfs['f%d' % i] = udf
+            args = ["a[%d]" % o for o in arg_offsets]
+            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
 
+    mapper = eval(mapper_str, udfs)
     func = lambda _, it: map(mapper, it)
 
     if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,

http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index c798fe5..513e174 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.python
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
@@ -75,20 +76,63 @@ case class FlatMapGroupsInPandasExec(
     val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
     val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
     val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
-    val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray)
-    val schema = StructType(child.schema.drop(groupingAttributes.length))
     val sessionLocalTimeZone = conf.sessionLocalTimeZone
     val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
 
+    // Deduplicate the grouping attributes.
+    // If a grouping attribute also appears in data attributes, then we don't need to send the
+    // grouping attribute to Python worker. If a grouping attribute is not in data attributes,
+    // then we need to send this grouping attribute to python worker.
+    //
+    // We use argOffsets to distinguish grouping attributes and data attributes as following:
+    //
+    // argOffsets[0] is the length of grouping attributes
+    // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes
+    // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes
+
+    val dataAttributes = child.output.drop(groupingAttributes.length)
+    val groupingIndicesInData = groupingAttributes.map { attribute =>
+      dataAttributes.indexWhere(attribute.semanticEquals)
+    }
+
+    val groupingArgOffsets = new ArrayBuffer[Int]
+    val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
+    val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)
+
+    // Non duplicate grouping attributes are added to nonDupGroupingAttributes and
+    // their offsets are 0, 1, 2 ...
+    // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and
+    // their offsets are n + index, where n is the total number of non duplicate grouping
+    // attributes and index is the index in the data attributes that the grouping attribute
+    // is a duplicate of.
+
+    groupingAttributes.zip(groupingIndicesInData).foreach {
+      case (attribute, index) =>
+        if (index == -1) {
+          groupingArgOffsets += nonDupGroupingAttributes.length
+          nonDupGroupingAttributes += attribute
+        } else {
+          groupingArgOffsets += index + nonDupGroupingSize
+        }
+    }
+
+    val dataArgOffsets = nonDupGroupingAttributes.length until
+      (nonDupGroupingAttributes.length + dataAttributes.length)
+
+    val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets)
+
+    // Attributes after deduplication
+    val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
+    val dedupSchema = StructType.fromAttributes(dedupAttributes)
+
     inputRDD.mapPartitionsInternal { iter =>
       val grouped = if (groupingAttributes.isEmpty) {
         Iterator(iter)
       } else {
         val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
-        val dropGrouping =
-          UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output)
+        val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
         groupedIter.map {
-          case (_, groupedRowIter) => groupedRowIter.map(dropGrouping)
+          case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
         }
       }
 
@@ -96,7 +140,7 @@ case class FlatMapGroupsInPandasExec(
 
       val columnarBatchIter = new ArrowPythonRunner(
         chainedFunc, bufferSize, reuseWorker,
-        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema,
         sessionLocalTimeZone, pandasRespectSessionTimeZone)
           .compute(grouped, context.partitionId(), context)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org