Posted to commits@spark.apache.org by gu...@apache.org on 2018/03/08 11:29:11 UTC
spark git commit: [SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF
Repository: spark
Updated Branches:
refs/heads/master d6632d185 -> 2cb23a8f5
[SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF
## What changes were proposed in this pull request?
This PR proposes to support an alternative function form with group aggregate pandas UDF.
The current form:
```
def foo(pdf):
    return ...
```
takes a single argument, which is a pandas DataFrame.
With this PR, an alternative form is supported:
```
def foo(key, pdf):
    return ...
```
The alternative form takes two arguments: a tuple representing the grouping key, and a pandas DataFrame containing the group's data.
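For illustration, a minimal sketch of both forms (assuming a running SparkSession bound to `spark`, with pandas and pyarrow installed; the column names and UDF names are made up for the example):
```
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

# Current form: the function sees only the group's data as a pandas DataFrame.
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())

# Alternative form: the grouping key arrives as a tuple (here one numpy.int64),
# so it does not need to be hardcoded or read back out of the data columns.
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def group_mean(key, pdf):
    return pd.DataFrame([key + (pdf.v.mean(),)])

df.groupby("id").apply(subtract_mean).show()
df.groupby("id").apply(group_mean).show()
```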
## How was this patch tested?
GroupbyApplyTests
Author: Li Jin <ic...@gmail.com>
Closes #20295 from icexelloss/SPARK-23011-groupby-apply-key.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2cb23a8f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2cb23a8f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2cb23a8f
Branch: refs/heads/master
Commit: 2cb23a8f51a151970c121015fcbad9beeafa8295
Parents: d6632d1
Author: Li Jin <ic...@gmail.com>
Authored: Thu Mar 8 20:29:07 2018 +0900
Committer: hyukjinkwon <gu...@gmail.com>
Committed: Thu Mar 8 20:29:07 2018 +0900
----------------------------------------------------------------------
python/pyspark/serializers.py | 18 +--
python/pyspark/sql/functions.py | 25 ++++
python/pyspark/sql/tests.py | 121 +++++++++++++++++--
python/pyspark/sql/types.py | 45 +++++--
python/pyspark/sql/udf.py | 19 ++-
python/pyspark/util.py | 16 +++
python/pyspark/worker.py | 49 ++++++--
.../python/FlatMapGroupsInPandasExec.scala | 56 ++++++++-
8 files changed, 294 insertions(+), 55 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 917e258..ebf5493 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -250,6 +250,15 @@ class ArrowStreamPandasSerializer(Serializer):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone
+ def arrow_to_pandas(self, arrow_column):
+ from pyspark.sql.types import from_arrow_type, \
+ _check_series_convert_date, _check_series_localize_timestamps
+
+ s = arrow_column.to_pandas()
+ s = _check_series_convert_date(s, from_arrow_type(arrow_column.type))
+ s = _check_series_localize_timestamps(s, self._timezone)
+ return s
+
def dump_stream(self, iterator, stream):
"""
Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
@@ -272,16 +281,11 @@ class ArrowStreamPandasSerializer(Serializer):
"""
Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
"""
- from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \
- _check_dataframe_localize_timestamps
import pyarrow as pa
reader = pa.open_stream(stream)
- schema = from_arrow_schema(reader.schema)
+
for batch in reader:
- pdf = batch.to_pandas()
- pdf = _check_dataframe_convert_date(pdf, schema)
- pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
- yield [c for _, c in pdf.iteritems()]
+ yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
def __repr__(self):
return "ArrowStreamPandasSerializer"
http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b9c0c57..dc1341a 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2267,6 +2267,31 @@ def pandas_udf(f=None, returnType=None, functionType=None):
| 2| 1.1094003924504583|
+---+-------------------+
+ Alternatively, the user can define a function that takes two arguments.
+ In this case, the grouping key will be passed as the first argument and the data will
+ be passed as the second argument. The grouping key will be passed as a tuple of numpy
+ data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
+ as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
+ This is useful when the user does not want to hardcode grouping key in the function.
+
+ >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+ >>> import pandas as pd # doctest: +SKIP
+ >>> df = spark.createDataFrame(
+ ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+ ... ("id", "v")) # doctest: +SKIP
+ >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP
+ ... def mean_udf(key, pdf):
+ ... # key is a tuple of one numpy.int64, which is the value
+ ... # of 'id' for the current group
+ ... return pd.DataFrame([key + (pdf.v.mean(),)])
+ >>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP
+ +---+---+
+ | id| v|
+ +---+---+
+ | 1|1.5|
+ | 2|6.0|
+ +---+---+
+
.. seealso:: :meth:`pyspark.sql.GroupedData.apply`
3. GROUPED_AGG
http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a9fe0b4..480815d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3903,7 +3903,7 @@ class PandasUDFTests(ReusedSQLTestCase):
return df
with self.assertRaisesRegexp(ValueError, 'Invalid function'):
@pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
- def foo(k, v):
+ def foo(k, v, w):
return k
@@ -4476,20 +4476,45 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
df = self.data.withColumn("arr", array(col("id")))
- foo_udf = pandas_udf(
+ # Different forms of group map pandas UDF, results of these are the same
+
+ output_schema = StructType(
+ [StructField('id', LongType()),
+ StructField('v', IntegerType()),
+ StructField('arr', ArrayType(LongType())),
+ StructField('v1', DoubleType()),
+ StructField('v2', LongType())])
+
+ udf1 = pandas_udf(
lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
- StructType(
- [StructField('id', LongType()),
- StructField('v', IntegerType()),
- StructField('arr', ArrayType(LongType())),
- StructField('v1', DoubleType()),
- StructField('v2', LongType())]),
+ output_schema,
PandasUDFType.GROUPED_MAP
)
- result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
- expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
- self.assertPandasEqual(expected, result)
+ udf2 = pandas_udf(
+ lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+ output_schema,
+ PandasUDFType.GROUPED_MAP
+ )
+
+ udf3 = pandas_udf(
+ lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+ output_schema,
+ PandasUDFType.GROUPED_MAP
+ )
+
+ result1 = df.groupby('id').apply(udf1).sort('id').toPandas()
+ expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True)
+
+ result2 = df.groupby('id').apply(udf2).sort('id').toPandas()
+ expected2 = expected1
+
+ result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
+ expected3 = expected1
+
+ self.assertPandasEqual(expected1, result1)
+ self.assertPandasEqual(expected2, result2)
+ self.assertPandasEqual(expected3, result3)
def test_register_grouped_map_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4648,6 +4673,80 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
result = df.groupby('time').apply(foo_udf).sort('time')
self.assertPandasEqual(df.toPandas(), result.toPandas())
+ def test_udf_with_key(self):
+ from pyspark.sql.functions import pandas_udf, col, PandasUDFType
+ df = self.data
+ pdf = df.toPandas()
+
+ def foo1(key, pdf):
+ import numpy as np
+ assert type(key) == tuple
+ assert type(key[0]) == np.int64
+
+ return pdf.assign(v1=key[0],
+ v2=pdf.v * key[0],
+ v3=pdf.v * pdf.id,
+ v4=pdf.v * pdf.id.mean())
+
+ def foo2(key, pdf):
+ import numpy as np
+ assert type(key) == tuple
+ assert type(key[0]) == np.int64
+ assert type(key[1]) == np.int32
+
+ return pdf.assign(v1=key[0],
+ v2=key[1],
+ v3=pdf.v * key[0],
+ v4=pdf.v + key[1])
+
+ def foo3(key, pdf):
+ assert type(key) == tuple
+ assert len(key) == 0
+ return pdf.assign(v1=pdf.v * pdf.id)
+
+ # v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
+ # v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
+ udf1 = pandas_udf(
+ foo1,
+ 'id long, v int, v1 long, v2 int, v3 long, v4 double',
+ PandasUDFType.GROUPED_MAP)
+
+ udf2 = pandas_udf(
+ foo2,
+ 'id long, v int, v1 long, v2 int, v3 int, v4 int',
+ PandasUDFType.GROUPED_MAP)
+
+ udf3 = pandas_udf(
+ foo3,
+ 'id long, v int, v1 long',
+ PandasUDFType.GROUPED_MAP)
+
+ # Test groupby column
+ result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
+ expected1 = pdf.groupby('id')\
+ .apply(lambda x: udf1.func((x.id.iloc[0],), x))\
+ .sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected1, result1)
+
+ # Test groupby expression
+ result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
+ expected2 = pdf.groupby(pdf.id % 2)\
+ .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
+ .sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected2, result2)
+
+ # Test complex groupby
+ result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
+ expected3 = pdf.groupby([pdf.id, pdf.v % 2])\
+ .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\
+ .sort_values(['id', 'v']).reset_index(drop=True)
+ self.assertPandasEqual(expected3, result3)
+
+ # Test empty groupby
+ result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
+ expected4 = udf3.func((), pdf)
+ self.assertPandasEqual(expected4, result4)
+
@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index cd85740..1632862 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1695,6 +1695,19 @@ def from_arrow_schema(arrow_schema):
for field in arrow_schema])
+def _check_series_convert_date(series, data_type):
+ """
+ Cast the series to datetime.date if it's a date type, otherwise returns the original series.
+
+ :param series: pandas.Series
+ :param data_type: a Spark data type for the series
+ """
+ if type(data_type) == DateType:
+ return series.dt.date
+ else:
+ return series
+
+
def _check_dataframe_convert_date(pdf, schema):
""" Correct date type value to use datetime.date.
@@ -1705,8 +1718,7 @@ def _check_dataframe_convert_date(pdf, schema):
:param schema: a Spark schema of the pandas.DataFrame
"""
for field in schema:
- if type(field.dataType) == DateType:
- pdf[field.name] = pdf[field.name].dt.date
+ pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType)
return pdf
@@ -1725,6 +1737,29 @@ def _get_local_timezone():
return os.environ.get('TZ', 'dateutil/:')
+def _check_series_localize_timestamps(s, timezone):
+ """
+ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
+
+ If the input series is not a timestamp series, then the same series is returned. If the input
+ series is a timestamp series, then a converted series is returned.
+
+ :param s: pandas.Series
+ :param timezone: the timezone to convert. if None then use local timezone
+ :return pandas.Series that have been converted to tz-naive
+ """
+ from pyspark.sql.utils import require_minimum_pandas_version
+ require_minimum_pandas_version()
+
+ from pandas.api.types import is_datetime64tz_dtype
+ tz = timezone or _get_local_timezone()
+ # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+ if is_datetime64tz_dtype(s.dtype):
+ return s.dt.tz_convert(tz).dt.tz_localize(None)
+ else:
+ return s
+
+
def _check_dataframe_localize_timestamps(pdf, timezone):
"""
Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
@@ -1736,12 +1771,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone):
from pyspark.sql.utils import require_minimum_pandas_version
require_minimum_pandas_version()
- from pandas.api.types import is_datetime64tz_dtype
- tz = timezone or _get_local_timezone()
for column, series in pdf.iteritems():
- # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
- if is_datetime64tz_dtype(series.dtype):
- pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None)
+ pdf[column] = _check_series_localize_timestamps(series, timezone)
return pdf
http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/sql/udf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index b9b4908..ce804c1 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -17,6 +17,8 @@
"""
User-defined function related classes and functions
"""
+import sys
+import inspect
import functools
from pyspark import SparkContext, since
@@ -24,6 +26,7 @@ from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \
_parse_datatype_string, to_arrow_type, to_arrow_schema
+from pyspark.util import _get_argspec
__all__ = ["UDFRegistration"]
@@ -41,18 +44,10 @@ def _create_udf(f, returnType, evalType):
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
- import inspect
- import sys
from pyspark.sql.utils import require_minimum_pyarrow_version
-
require_minimum_pyarrow_version()
- if sys.version_info[0] < 3:
- # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
- # See SPARK-23569.
- argspec = inspect.getargspec(f)
- else:
- argspec = inspect.getfullargspec(f)
+ argspec = _get_argspec(f)
if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \
argspec.varargs is None:
@@ -61,11 +56,11 @@ def _create_udf(f, returnType, evalType):
"Instead, create a 1-arg pandas_udf and ignore the arg in your function."
)
- if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1:
+ if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
+ and len(argspec.args) not in (1, 2):
raise ValueError(
"Invalid function: pandas_udfs with function type GROUPED_MAP "
- "must take a single arg that is a pandas DataFrame."
- )
+ "must take either one argument (data) or two arguments (key, data).")
# Set the name of the UserDefinedFunction object to be the name of function f
udf_obj = UserDefinedFunction(
http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index ad4a0bc..6837b18 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -15,6 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+
+import sys
+import inspect
from py4j.protocol import Py4JJavaError
__all__ = []
@@ -45,6 +48,19 @@ def _exception_message(excp):
return str(excp)
+def _get_argspec(f):
+ """
+ Get argspec of a function. Supports both Python 2 and Python 3.
+ """
+ # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
+ # See SPARK-23569.
+ if sys.version_info[0] < 3:
+ argspec = inspect.getargspec(f)
+ else:
+ argspec = inspect.getfullargspec(f)
+ return argspec
+
+
if __name__ == "__main__":
import doctest
(failure_count, test_count) = doctest.testmod()
http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 89a3a92..202cac3 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -34,6 +34,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
BatchedSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import to_arrow_type
+from pyspark.util import _get_argspec
from pyspark import shuffle
pickleSer = PickleSerializer()
@@ -91,10 +92,16 @@ def wrap_scalar_pandas_udf(f, return_type):
def wrap_grouped_map_pandas_udf(f, return_type):
- def wrapped(*series):
+ def wrapped(key_series, value_series):
import pandas as pd
+ argspec = _get_argspec(f)
+
+ if len(argspec.args) == 1:
+ result = f(pd.concat(value_series, axis=1))
+ elif len(argspec.args) == 2:
+ key = tuple(s[0] for s in key_series)
+ result = f(key, pd.concat(value_series, axis=1))
- result = f(pd.concat(series, axis=1))
if not isinstance(result, pd.DataFrame):
raise TypeError("Return type of the user-defined function should be "
"pandas.DataFrame, but is {}".format(type(result)))
@@ -149,18 +156,36 @@ def read_udfs(pickleSer, infile, eval_type):
num_udfs = read_int(infile)
udfs = {}
call_udf = []
- for i in range(num_udfs):
+ mapper_str = ""
+ if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+ # Create function like this:
+ # lambda a: f([a[0]], [a[0], a[1]])
+
+ # We assume there is only one UDF here because grouped map doesn't
+ # support combining multiple UDFs.
+ assert num_udfs == 1
+
+ # See FlatMapGroupsInPandasExec for how arg_offsets are used to
+ # distinguish between grouping attributes and data attributes
arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
- udfs['f%d' % i] = udf
- args = ["a[%d]" % o for o in arg_offsets]
- call_udf.append("f%d(%s)" % (i, ", ".join(args)))
- # Create function like this:
- # lambda a: (f0(a0), f1(a1, a2), f2(a3))
- # In the special case of a single UDF this will return a single result rather
- # than a tuple of results; this is the format that the JVM side expects.
- mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
- mapper = eval(mapper_str, udfs)
+ udfs['f'] = udf
+ split_offset = arg_offsets[0] + 1
+ arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
+ arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
+ mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
+ else:
+ # Create function like this:
+ # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
+ # In the special case of a single UDF this will return a single result rather
+ # than a tuple of results; this is the format that the JVM side expects.
+ for i in range(num_udfs):
+ arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type)
+ udfs['f%d' % i] = udf
+ args = ["a[%d]" % o for o in arg_offsets]
+ call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+ mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+ mapper = eval(mapper_str, udfs)
func = lambda _, it: map(mapper, it)
if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
http://git-wip-us.apache.org/repos/asf/spark/blob/2cb23a8f/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index c798fe5..513e174 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.python
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
@@ -75,20 +76,63 @@ case class FlatMapGroupsInPandasExec(
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)
val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
- val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray)
- val schema = StructType(child.schema.drop(groupingAttributes.length))
val sessionLocalTimeZone = conf.sessionLocalTimeZone
val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone
+ // Deduplicate the grouping attributes.
+ // If a grouping attribute also appears in data attributes, then we don't need to send the
+ // grouping attribute to Python worker. If a grouping attribute is not in data attributes,
+ // then we need to send this grouping attribute to python worker.
+ //
+ // We use argOffsets to distinguish grouping attributes and data attributes as following:
+ //
+ // argOffsets[0] is the length of grouping attributes
+ // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes
+ // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes
+
+ val dataAttributes = child.output.drop(groupingAttributes.length)
+ val groupingIndicesInData = groupingAttributes.map { attribute =>
+ dataAttributes.indexWhere(attribute.semanticEquals)
+ }
+
+ val groupingArgOffsets = new ArrayBuffer[Int]
+ val nonDupGroupingAttributes = new ArrayBuffer[Attribute]
+ val nonDupGroupingSize = groupingIndicesInData.count(_ == -1)
+
+ // Non duplicate grouping attributes are added to nonDupGroupingAttributes and
+ // their offsets are 0, 1, 2 ...
+ // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and
+ // their offsets are n + index, where n is the total number of non duplicate grouping
+ // attributes and index is the index in the data attributes that the grouping attribute
+ // is a duplicate of.
+
+ groupingAttributes.zip(groupingIndicesInData).foreach {
+ case (attribute, index) =>
+ if (index == -1) {
+ groupingArgOffsets += nonDupGroupingAttributes.length
+ nonDupGroupingAttributes += attribute
+ } else {
+ groupingArgOffsets += index + nonDupGroupingSize
+ }
+ }
+
+ val dataArgOffsets = nonDupGroupingAttributes.length until
+ (nonDupGroupingAttributes.length + dataAttributes.length)
+
+ val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets)
+
+ // Attributes after deduplication
+ val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes
+ val dedupSchema = StructType.fromAttributes(dedupAttributes)
+
inputRDD.mapPartitionsInternal { iter =>
val grouped = if (groupingAttributes.isEmpty) {
Iterator(iter)
} else {
val groupedIter = GroupedIterator(iter, groupingAttributes, child.output)
- val dropGrouping =
- UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output)
+ val dedupProj = UnsafeProjection.create(dedupAttributes, child.output)
groupedIter.map {
- case (_, groupedRowIter) => groupedRowIter.map(dropGrouping)
+ case (_, groupedRowIter) => groupedRowIter.map(dedupProj)
}
}
@@ -96,7 +140,7 @@ case class FlatMapGroupsInPandasExec(
val columnarBatchIter = new ArrowPythonRunner(
chainedFunc, bufferSize, reuseWorker,
- PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema,
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema,
sessionLocalTimeZone, pandasRespectSessionTimeZone)
.compute(grouped, context.partitionId(), context)
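To make the argOffsets layout in the comments above concrete, here is a small illustrative Python walkthrough (not part of the commit; the helper function and example columns are hypothetical):
```
# Hypothetical walkthrough of the argOffsets scheme described in
# FlatMapGroupsInPandasExec, for df.groupby('id').apply(udf) on a DataFrame
# with columns (id, v). The child output is assumed to be [id, id, v]
# (grouping attribute prepended to the data attributes), so the grouping
# attribute 'id' duplicates data column 0 and is not sent twice:
#
#   groupingArgOffsets = [0]           # duplicate -> offset of 'id' in the data
#   dataArgOffsets     = [0, 1]        # all data columns
#   argOffsets         = [1, 0, 0, 1]  # [num grouping attrs] ++ grouping ++ data
#
# With these offsets the worker builds a mapper equivalent to
#   lambda a: f([a[0]], [a[0], a[1]])
# i.e. key series = [id], data series = [id, v].
def split_key_and_data(columns, arg_offsets):
    """Split the deduplicated columns of one group into (key, data) series."""
    num_grouping = arg_offsets[0]
    key = [columns[o] for o in arg_offsets[1:1 + num_grouping]]
    data = [columns[o] for o in arg_offsets[1 + num_grouping:]]
    return key, data

# Example with two pandas.Series for the deduplicated columns [id, v]:
#   split_key_and_data([id_series, v_series], [1, 0, 0, 1])
#   -> ([id_series], [id_series, v_series])
```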
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org