You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2019/03/21 08:45:19 UTC

[spark] branch master updated: [SPARK-27163][PYTHON] Cleanup and consolidate Pandas UDF functionality

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new be08b41  [SPARK-27163][PYTHON] Cleanup and consolidate Pandas UDF functionality
be08b41 is described below

commit be08b415dadd2725be2d2c8f8890a425fbad9338
Author: Bryan Cutler <cu...@gmail.com>
AuthorDate: Thu Mar 21 17:44:51 2019 +0900

    [SPARK-27163][PYTHON] Cleanup and consolidate Pandas UDF functionality
    
    ## What changes were proposed in this pull request?
    
    This change is a cleanup and consolidation of 3 areas related to Pandas UDFs:
    
    1) `ArrowStreamPandasSerializer` now inherits from `ArrowStreamSerializer` and uses the base class's `dump_stream` and `load_stream` to create the Arrow reader/writer and to send Arrow record batches. `ArrowStreamPandasSerializer` itself only handles the conversions to/from Pandas, passing iterators of Arrow record batches to and from the base class. This removes the duplicated creation of Arrow readers/writers.
    
    2) `createDataFrame` with Arrow now uses `ArrowStreamPandasSerializer` instead of doing its own conversions from Pandas to Arrow and sending record batches through `ArrowStreamSerializer`.
    
    3) Grouped Map UDFs now reuse existing logic in `ArrowStreamPandasSerializer` to send Pandas DataFrame results as a `StructType` instead of separating each column from the DataFrame. This makes the code a little more consistent with the Python worker, but does require that the returned StructType column is flattened out in `FlatMapGroupsInPandasExec` in Scala.
    
    ## How was this patch tested?
    
    Ran the existing tests, and also ran the tests with pyarrow 0.12.0.
    
    Closes #24095 from BryanCutler/arrow-refactor-cleanup-UDFs.
    
    Authored-by: Bryan Cutler <cu...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/serializers.py                      | 218 +++++++++++----------
 python/pyspark/sql/session.py                      |  42 ++--
 python/pyspark/worker.py                           |  23 +--
 .../sql/execution/arrow/ArrowConverters.scala      |   1 -
 .../python/FlatMapGroupsInPandasExec.scala         |  12 +-
 5 files changed, 161 insertions(+), 135 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 0c3c68e..58f7552 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -245,92 +245,13 @@ class ArrowStreamSerializer(Serializer):
         return "ArrowStreamSerializer"
 
 
-def _create_batch(series, timezone, safecheck, assign_cols_by_name):
+class ArrowStreamPandasSerializer(ArrowStreamSerializer):
     """
-    Create an Arrow record batch from the given pandas.Series or list of Series, with optional type.
+    Serializes Pandas.Series as Arrow data with Arrow streaming format.
 
-    :param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
     :param timezone: A timezone to respect when handling timestamp values
-    :return: Arrow RecordBatch
-    """
-    import decimal
-    from distutils.version import LooseVersion
-    import pandas as pd
-    import pyarrow as pa
-    from pyspark.sql.types import _check_series_convert_timestamps_internal
-    # Make input conform to [(series1, type1), (series2, type2), ...]
-    if not isinstance(series, (list, tuple)) or \
-            (len(series) == 2 and isinstance(series[1], pa.DataType)):
-        series = [series]
-    series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
-
-    def create_array(s, t):
-        mask = s.isnull()
-        # Ensure timestamp series are in expected form for Spark internal representation
-        # TODO: maybe don't need None check anymore as of Arrow 0.9.1
-        if t is not None and pa.types.is_timestamp(t):
-            s = _check_series_convert_timestamps_internal(s.fillna(0), timezone)
-            # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
-            return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
-        elif t is not None and pa.types.is_string(t) and sys.version < '3':
-            # TODO: need decode before converting to Arrow in Python 2
-            # TODO: don't need as of Arrow 0.9.1
-            return pa.Array.from_pandas(s.apply(
-                lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
-        elif t is not None and pa.types.is_decimal(t) and \
-                LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
-            # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
-            return pa.Array.from_pandas(s.apply(
-                lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
-        elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
-            # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
-            return pa.Array.from_pandas(s, mask=mask, type=t)
-
-        try:
-            array = pa.Array.from_pandas(s, mask=mask, type=t, safe=safecheck)
-        except pa.ArrowException as e:
-            error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
-                        "Array (%s). It can be caused by overflows or other unsafe " + \
-                        "conversions warned by Arrow. Arrow safe type check can be " + \
-                        "disabled by using SQL config " + \
-                        "`spark.sql.execution.pandas.arrowSafeTypeConversion`."
-            raise RuntimeError(error_msg % (s.dtype, t), e)
-        return array
-
-    arrs = []
-    for s, t in series:
-        if t is not None and pa.types.is_struct(t):
-            if not isinstance(s, pd.DataFrame):
-                raise ValueError("A field of type StructType expects a pandas.DataFrame, "
-                                 "but got: %s" % str(type(s)))
-
-            # Input partition and result pandas.DataFrame empty, make empty Arrays with struct
-            if len(s) == 0 and len(s.columns) == 0:
-                arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
-            # Assign result columns by schema name if user labeled with strings
-            elif assign_cols_by_name and any(isinstance(name, basestring) for name in s.columns):
-                arrs_names = [(create_array(s[field.name], field.type), field.name) for field in t]
-            # Assign result columns by  position
-            else:
-                arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
-                              for i, field in enumerate(t)]
-
-            struct_arrs, struct_names = zip(*arrs_names)
-
-            # TODO: from_arrays args switched for v0.9.0, remove when bump minimum pyarrow version
-            if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
-                arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
-            else:
-                arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
-        else:
-            arrs.append(create_array(s, t))
-
-    return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
-
-
-class ArrowStreamPandasSerializer(Serializer):
-    """
-    Serializes Pandas.Series as Arrow data with Arrow streaming format.
+    :param safecheck: If True, conversion from Arrow to Pandas checks for overflow/truncation
+    :param assign_cols_by_name: If True, then Pandas DataFrames will get columns by name
     """
 
     def __init__(self, timezone, safecheck, assign_cols_by_name):
@@ -347,39 +268,138 @@ class ArrowStreamPandasSerializer(Serializer):
         s = _check_series_localize_timestamps(s, self._timezone)
         return s
 
+    def _create_batch(self, series):
+        """
+        Create an Arrow record batch from the given pandas.Series or list of Series,
+        with optional type.
+
+        :param series: A single pandas.Series, list of Series, or list of (series, arrow_type)
+        :return: Arrow RecordBatch
+        """
+        import decimal
+        from distutils.version import LooseVersion
+        import pandas as pd
+        import pyarrow as pa
+        from pyspark.sql.types import _check_series_convert_timestamps_internal
+        # Make input conform to [(series1, type1), (series2, type2), ...]
+        if not isinstance(series, (list, tuple)) or \
+                (len(series) == 2 and isinstance(series[1], pa.DataType)):
+            series = [series]
+        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+
+        def create_array(s, t):
+            mask = s.isnull()
+            # Ensure timestamp series are in expected form for Spark internal representation
+            # TODO: maybe don't need None check anymore as of Arrow 0.9.1
+            if t is not None and pa.types.is_timestamp(t):
+                s = _check_series_convert_timestamps_internal(s.fillna(0), self._timezone)
+                # TODO: need cast after Arrow conversion, ns values cause error with pandas 0.19.2
+                return pa.Array.from_pandas(s, mask=mask).cast(t, safe=False)
+            elif t is not None and pa.types.is_string(t) and sys.version < '3':
+                # TODO: need decode before converting to Arrow in Python 2
+                # TODO: don't need as of Arrow 0.9.1
+                return pa.Array.from_pandas(s.apply(
+                    lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
+            elif t is not None and pa.types.is_decimal(t) and \
+                    LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+                # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
+                return pa.Array.from_pandas(s.apply(
+                    lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
+            elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
+                # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
+                return pa.Array.from_pandas(s, mask=mask, type=t)
+
+            try:
+                array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
+            except pa.ArrowException as e:
+                error_msg = "Exception thrown when converting pandas.Series (%s) to Arrow " + \
+                            "Array (%s). It can be caused by overflows or other unsafe " + \
+                            "conversions warned by Arrow. Arrow safe type check can be " + \
+                            "disabled by using SQL config " + \
+                            "`spark.sql.execution.pandas.arrowSafeTypeConversion`."
+                raise RuntimeError(error_msg % (s.dtype, t), e)
+            return array
+
+        arrs = []
+        for s, t in series:
+            if t is not None and pa.types.is_struct(t):
+                if not isinstance(s, pd.DataFrame):
+                    raise ValueError("A field of type StructType expects a pandas.DataFrame, "
+                                     "but got: %s" % str(type(s)))
+
+                # Input partition and result pandas.DataFrame empty, make empty Arrays with struct
+                if len(s) == 0 and len(s.columns) == 0:
+                    arrs_names = [(pa.array([], type=field.type), field.name) for field in t]
+                # Assign result columns by schema name if user labeled with strings
+                elif self._assign_cols_by_name and any(isinstance(name, basestring)
+                                                       for name in s.columns):
+                    arrs_names = [(create_array(s[field.name], field.type), field.name)
+                                  for field in t]
+                # Assign result columns by  position
+                else:
+                    arrs_names = [(create_array(s[s.columns[i]], field.type), field.name)
+                                  for i, field in enumerate(t)]
+
+                struct_arrs, struct_names = zip(*arrs_names)
+
+                # TODO: from_arrays args switched for v0.9.0, remove when bump min pyarrow version
+                if LooseVersion(pa.__version__) < LooseVersion("0.9.0"):
+                    arrs.append(pa.StructArray.from_arrays(struct_names, struct_arrs))
+                else:
+                    arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names))
+            else:
+                arrs.append(create_array(s, t))
+
+        return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
+
     def dump_stream(self, iterator, stream):
         """
         Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
         a list of series accompanied by an optional pyarrow type to coerce the data to.
         """
-        import pyarrow as pa
-        writer = None
-        try:
-            for series in iterator:
-                batch = _create_batch(series, self._timezone, self._safecheck,
-                                      self._assign_cols_by_name)
-                if writer is None:
-                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
-                    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
-                writer.write_batch(batch)
-        finally:
-            if writer is not None:
-                writer.close()
+        batches = (self._create_batch(series) for series in iterator)
+        super(ArrowStreamPandasSerializer, self).dump_stream(batches, stream)
 
     def load_stream(self, stream):
         """
         Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
         """
+        batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
         import pyarrow as pa
-        reader = pa.ipc.open_stream(stream)
-
-        for batch in reader:
+        for batch in batches:
             yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
 
     def __repr__(self):
         return "ArrowStreamPandasSerializer"
 
 
+class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
+    """
+    Serializer used by Python worker to evaluate Pandas UDFs
+    """
+
+    def dump_stream(self, iterator, stream):
+        """
+        Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
+        This should be sent after creating the first record batch so in case of an error, it can
+        be sent back to the JVM before the Arrow stream starts.
+        """
+
+        def init_stream_yield_batches():
+            should_write_start_length = True
+            for series in iterator:
+                batch = self._create_batch(series)
+                if should_write_start_length:
+                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
+                    should_write_start_length = False
+                yield batch
+
+        return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
+
+    def __repr__(self):
+        return "ArrowStreamPandasUDFSerializer"
+
+
 class BatchedSerializer(Serializer):
 
     """
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 32a2c8a..b11e0f3 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -530,8 +530,9 @@ class SparkSession(object):
         to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
         data types will be used to coerce the data in Pandas to Arrow conversion.
         """
-        from pyspark.serializers import ArrowStreamSerializer, _create_batch
-        from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType
+        from distutils.version import LooseVersion
+        from pyspark.serializers import ArrowStreamPandasSerializer
+        from pyspark.sql.types import from_arrow_type, to_arrow_type, TimestampType
         from pyspark.sql.utils import require_minimum_pandas_version, \
             require_minimum_pyarrow_version
 
@@ -539,6 +540,19 @@ class SparkSession(object):
         require_minimum_pyarrow_version()
 
         from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
+        import pyarrow as pa
+
+        # Create the Spark schema from list of names passed in with Arrow types
+        if isinstance(schema, (list, tuple)):
+            if LooseVersion(pa.__version__) < LooseVersion("0.12.0"):
+                temp_batch = pa.RecordBatch.from_pandas(pdf[0:100], preserve_index=False)
+                arrow_schema = temp_batch.schema
+            else:
+                arrow_schema = pa.Schema.from_pandas(pdf, preserve_index=False)
+            struct = StructType()
+            for name, field in zip(schema, arrow_schema):
+                struct.add(name, from_arrow_type(field.type), nullable=field.nullable)
+            schema = struct
 
         # Determine arrow types to coerce data when creating batches
         if isinstance(schema, StructType):
@@ -555,23 +569,16 @@ class SparkSession(object):
         step = -(-len(pdf) // self.sparkContext.defaultParallelism)  # round int up
         pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))
 
-        # Create Arrow record batches
-        safecheck = self._wrapped._conf.arrowSafeTypeConversion()
-        col_by_name = True  # col by name only applies to StructType columns, can't happen here
-        batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
-                                 timezone, safecheck, col_by_name)
-                   for pdf_slice in pdf_slices]
-
-        # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
-        if isinstance(schema, (list, tuple)):
-            struct = from_arrow_schema(batches[0].schema)
-            for i, name in enumerate(schema):
-                struct.fields[i].name = name
-                struct.names[i] = name
-            schema = struct
+        # Create list of Arrow (columns, type) for serializer dump_stream
+        arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]
+                      for pdf_slice in pdf_slices]
 
         jsqlContext = self._wrapped._jsqlContext
 
+        safecheck = self._wrapped._conf.arrowSafeTypeConversion()
+        col_by_name = True  # col by name only applies to StructType columns, can't happen here
+        ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)
+
         def reader_func(temp_filename):
             return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)
 
@@ -579,8 +586,7 @@ class SparkSession(object):
             return self._jvm.ArrowRDDServer(jsqlContext)
 
         # Create Spark DataFrame from Arrow stream file, using one batch per partition
-        jrdd = self._sc._serialize_to_jvm(batches, ArrowStreamSerializer(), reader_func,
-                                          create_RDD_server)
+        jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
         jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
         df = DataFrame(jdf, self._wrapped)
         df._schema = schema
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 7811012..0e3bef6 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -38,7 +38,7 @@ from pyspark.files import SparkFiles
 from pyspark.rdd import PythonEvalType
 from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
-    BatchedSerializer, ArrowStreamPandasSerializer
+    BatchedSerializer, ArrowStreamPandasUDFSerializer
 from pyspark.sql.types import to_arrow_type, StructType
 from pyspark.util import _get_argspec, fail_on_stopiteration
 from pyspark import shuffle
@@ -103,10 +103,7 @@ def wrap_scalar_pandas_udf(f, return_type):
     return lambda *a: (verify_result_length(*a), arrow_return_type)
 
 
-def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
-    assign_cols_by_name = runner_conf.get(
-        "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")
-    assign_cols_by_name = assign_cols_by_name.lower() == "true"
+def wrap_grouped_map_pandas_udf(f, return_type, argspec):
 
     def wrapped(key_series, value_series):
         import pandas as pd
@@ -125,15 +122,9 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
                 "Number of columns of the returned pandas.DataFrame "
                 "doesn't match specified schema. "
                 "Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
+        return result
 
-        # Assign result columns by schema name if user labeled with strings, else use position
-        if assign_cols_by_name and any(isinstance(name, basestring) for name in result.columns):
-            return [(result[field.name], to_arrow_type(field.dataType)) for field in return_type]
-        else:
-            return [(result[result.columns[i]], to_arrow_type(field.dataType))
-                    for i, field in enumerate(return_type)]
-
-    return wrapped
+    return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
 
 
 def wrap_grouped_agg_pandas_udf(f, return_type):
@@ -227,7 +218,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
         return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = _get_argspec(row_func)  # signature was lost when wrapping it
-        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
+        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
         return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
@@ -257,12 +248,12 @@ def read_udfs(pickleSer, infile, eval_type):
         timezone = runner_conf.get("spark.sql.session.timeZone", None)
         safecheck = runner_conf.get("spark.sql.execution.pandas.arrowSafeTypeConversion",
                                     "false").lower() == 'true'
-        # NOTE: this is duplicated from wrap_grouped_map_pandas_udf
+        # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
         assign_cols_by_name = runner_conf.get(
             "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true")\
             .lower() == "true"
 
-        ser = ArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
+        ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name)
     else:
         ser = BatchedSerializer(PickleSerializer(), 100)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 2bf6a58..884dc8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -31,7 +31,6 @@ import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer}
 import org.apache.spark.TaskContext
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index e9cff1a..ce755ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
 import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.arrow.ArrowUtils
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
 
 /**
  * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]]
@@ -145,7 +146,16 @@ case class FlatMapGroupsInPandasExec(
         sessionLocalTimeZone,
         pythonRunnerConf).compute(grouped, context.partitionId(), context)
 
-      columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output))
+      val unsafeProj = UnsafeProjection.create(output, output)
+
+      columnarBatchIter.flatMap { batch =>
+        // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here
+        val structVector = batch.column(0).asInstanceOf[ArrowColumnVector]
+        val outputVectors = output.indices.map(structVector.getChild)
+        val flattenedBatch = new ColumnarBatch(outputVectors.toArray)
+        flattenedBatch.setNumRows(batch.numRows())
+        flattenedBatch.rowIterator.asScala
+      }.map(unsafeProj)
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org