Posted to commits@spark.apache.org by we...@apache.org on 2023/03/27 09:50:48 UTC

[spark] branch master updated: [SPARK-42929] make mapInPandas / mapInArrow support "is_barrier"

This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2a1ac07132b [SPARK-42929] make mapInPandas / mapInArrow support "is_barrier"
2a1ac07132b is described below

commit 2a1ac07132b7abc13e56b9a632b3dece7e4b60ea
Author: Weichen Xu <we...@databricks.com>
AuthorDate: Mon Mar 27 17:50:23 2023 +0800

    [SPARK-42929] make mapInPandas / mapInArrow support "is_barrier"
    
    ### What changes were proposed in this pull request?
    
    Make `mapInPandas` / `mapInArrow` in the Spark Connect Python client support an `is_barrier` parameter.
    
    ### Why are the changes needed?
    
    Feature parity: the classic PySpark API already supports barrier execution for these functions, while the Spark Connect planner previously hard-coded non-barrier mode.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. `mapInPandas` and `mapInArrow` on Spark Connect DataFrames now accept an `is_barrier` argument, defaulting to `False`.
    
    ### How was this patch tested?
    
    Manually:
    
    `bin/pyspark --remote local`:
    
    ```
    import pyarrow

    df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

    # mapInPandas: the function receives an iterator of pandas DataFrames.
    def filter_func(iterator):
        for pdf in iterator:
            yield pdf[pdf.id == 1]

    df.mapInPandas(filter_func, df.schema, is_barrier=True).collect()

    # mapInArrow: the function receives an iterator of pyarrow RecordBatches.
    def filter_func(iterator):
        for batch in iterator:
            pdf = batch.to_pandas()
            yield pyarrow.RecordBatch.from_pandas(pdf[pdf.id == 1])

    df.mapInArrow(filter_func, df.schema, is_barrier=True).collect()
    ```
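    
    For context, barrier execution is what lets the mapped function coordinate across all tasks in the stage. Below is a minimal sketch of such coordination, assuming the Python worker exposes `BarrierTaskContext` inside the barrier stage as it does for classic barrier RDD jobs:
    
    ```
    from pyspark import BarrierTaskContext
    
    def sync_func(iterator):
        # Every task in the barrier stage blocks here until all tasks arrive.
        BarrierTaskContext.get().barrier()
        for pdf in iterator:
            yield pdf
    
    df.mapInPandas(sync_func, df.schema, is_barrier=True).collect()
    ```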
    
    Closes #40559 from WeichenXu123/spark-connect-barrier-mode.
    
    Authored-by: Weichen Xu <we...@databricks.com>
    Signed-off-by: Weichen Xu <we...@databricks.com>
---
 .../main/protobuf/spark/connect/relations.proto    |  3 +++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  5 ++--
 python/pyspark/sql/connect/dataframe.py            | 21 +++++++++++----
 python/pyspark/sql/connect/plan.py                 |  8 +++++-
 python/pyspark/sql/connect/proto/relations_pb2.py  | 24 ++++++++---------
 python/pyspark/sql/connect/proto/relations_pb2.pyi | 31 ++++++++++++++++++++--
 6 files changed, 70 insertions(+), 22 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 976bd68e7fe..c965a6c8d32 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -794,6 +794,9 @@ message MapPartitions {
 
   // (Required) Input user-defined function.
   CommonInlineUserDefinedFunction func = 2;
+
+  // (Optional) isBarrier.
+  optional bool is_barrier = 3;
 }
 
 message GroupMap {
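
The new field is declared as proto3 `optional`, so field presence is tracked and the server can distinguish "unset" from "explicitly false". A minimal sketch of checking presence from Python, assuming the regenerated stubs below are importable:

```
from pyspark.sql.connect.proto import relations_pb2

mp = relations_pb2.MapPartitions()
print(mp.HasField("is_barrier"))  # False: unset, the server falls back to non-barrier
mp.is_barrier = True
print(mp.HasField("is_barrier"))  # True: presence is recorded for optional fields
```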
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e7911ccdf11..e7e88cab643 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -484,19 +484,20 @@ class SparkConnectPlanner(val session: SparkSession) {
   private def transformMapPartitions(rel: proto.MapPartitions): LogicalPlan = {
     val commonUdf = rel.getFunc
     val pythonUdf = transformPythonUDF(commonUdf)
+    val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false
     pythonUdf.evalType match {
       case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF =>
         logical.MapInPandas(
           pythonUdf,
           pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
           transformRelation(rel.getInput),
-          false)
+          isBarrier)
       case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
         logical.PythonMapInArrow(
           pythonUdf,
           pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
           transformRelation(rel.getInput),
-          false)
+          isBarrier)
       case _ =>
         throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported")
     }
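
Note the fallback above: clients built before this change never set the field, so `rel.hasIsBarrier` is false and the planner keeps the previous non-barrier behavior. A hypothetical Python rendering of the same check, for illustration only:

```
def resolve_is_barrier(rel):
    # Hypothetical mirror of the Scala fallback: an unset field means
    # non-barrier, preserving behavior for older clients.
    return rel.is_barrier if rel.HasField("is_barrier") else False
```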
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 2dfc8e72193..10426c3c28d 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1623,6 +1623,7 @@ class DataFrame:
         func: "PandasMapIterFunction",
         schema: Union[StructType, str],
         evalType: int,
+        is_barrier: bool,
     ) -> "DataFrame":
         from pyspark.sql.connect.udf import UserDefinedFunction
 
@@ -1636,21 +1637,31 @@ class DataFrame:
         )
 
         return DataFrame.withPlan(
-            plan.MapPartitions(child=self._plan, function=udf_obj, cols=self.columns),
+            plan.MapPartitions(
+                child=self._plan, function=udf_obj, cols=self.columns, is_barrier=is_barrier
+            ),
             session=self._session,
         )
 
     def mapInPandas(
-        self, func: "PandasMapIterFunction", schema: Union[StructType, str]
+        self,
+        func: "PandasMapIterFunction",
+        schema: Union[StructType, str],
+        is_barrier: bool = False,
     ) -> "DataFrame":
-        return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
+        return self._map_partitions(
+            func, schema, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, is_barrier
+        )
 
     mapInPandas.__doc__ = PySparkDataFrame.mapInPandas.__doc__
 
     def mapInArrow(
-        self, func: "ArrowMapIterFunction", schema: Union[StructType, str]
+        self,
+        func: "ArrowMapIterFunction",
+        schema: Union[StructType, str],
+        is_barrier: bool = False,
     ) -> "DataFrame":
-        return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_ARROW_ITER_UDF)
+        return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, is_barrier)
 
     mapInArrow.__doc__ = PySparkDataFrame.mapInArrow.__doc__
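
On the client side `is_barrier` defaults to `False`, so existing call sites are unaffected. A short usage sketch, reusing the `spark` session and `df` from the manual test above:

```
def identity_func(iterator):
    for pdf in iterator:
        yield pdf

# Default: non-barrier, identical to the behavior before this change.
df.mapInPandas(identity_func, df.schema).collect()

# Opt in to barrier execution explicitly.
df.mapInPandas(identity_func, df.schema, is_barrier=True).collect()
```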
 
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 7988cc33009..12a5879db0f 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1912,17 +1912,23 @@ class MapPartitions(LogicalPlan):
     """Logical plan object for a mapPartitions-equivalent API: mapInPandas, mapInArrow."""
 
     def __init__(
-        self, child: Optional["LogicalPlan"], function: "UserDefinedFunction", cols: List[str]
+        self,
+        child: Optional["LogicalPlan"],
+        function: "UserDefinedFunction",
+        cols: List[str],
+        is_barrier: bool,
     ) -> None:
         super().__init__(child)
 
         self._func = function._build_common_inline_user_defined_function(*cols)
+        self._is_barrier = is_barrier
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         assert self._child is not None
         plan = self._create_proto_relation()
         plan.map_partitions.input.CopyFrom(self._child.plan(session))
         plan.map_partitions.func.CopyFrom(self._func.to_plan_udf(session))
+        plan.map_partitions.is_barrier = self._is_barrier
         return plan
 
 
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 81a0499666b..80e66fd4aae 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xaf\x14\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xaf\x14\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -828,17 +828,17 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _REPARTITIONBYEXPRESSION._serialized_start = 9782
     _REPARTITIONBYEXPRESSION._serialized_end = 9985
     _MAPPARTITIONS._serialized_start = 9988
-    _MAPPARTITIONS._serialized_end = 10118
-    _GROUPMAP._serialized_start = 10121
-    _GROUPMAP._serialized_end = 10324
-    _COGROUPMAP._serialized_start = 10327
-    _COGROUPMAP._serialized_end = 10679
-    _COLLECTMETRICS._serialized_start = 10682
-    _COLLECTMETRICS._serialized_end = 10818
-    _PARSE._serialized_start = 10821
-    _PARSE._serialized_end = 11209
+    _MAPPARTITIONS._serialized_end = 10169
+    _GROUPMAP._serialized_start = 10172
+    _GROUPMAP._serialized_end = 10375
+    _COGROUPMAP._serialized_start = 10378
+    _COGROUPMAP._serialized_end = 10730
+    _COLLECTMETRICS._serialized_start = 10733
+    _COLLECTMETRICS._serialized_end = 10869
+    _PARSE._serialized_start = 10872
+    _PARSE._serialized_end = 11260
     _PARSE_OPTIONSENTRY._serialized_start = 3293
     _PARSE_OPTIONSENTRY._serialized_end = 3351
-    _PARSE_PARSEFORMAT._serialized_start = 11110
-    _PARSE_PARSEFORMAT._serialized_end = 11198
+    _PARSE_PARSEFORMAT._serialized_start = 11161
+    _PARSE_PARSEFORMAT._serialized_end = 11249
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index f287a740346..c3cf733a995 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -2759,25 +2759,52 @@ class MapPartitions(google.protobuf.message.Message):
 
     INPUT_FIELD_NUMBER: builtins.int
     FUNC_FIELD_NUMBER: builtins.int
+    IS_BARRIER_FIELD_NUMBER: builtins.int
     @property
     def input(self) -> global___Relation:
         """(Required) Input relation for a mapPartitions-equivalent API: mapInPandas, mapInArrow."""
     @property
     def func(self) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction:
         """(Required) Input user-defined function."""
+    is_barrier: builtins.bool
+    """(Optional) isBarrier."""
     def __init__(
         self,
         *,
         input: global___Relation | None = ...,
         func: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
         | None = ...,
+        is_barrier: builtins.bool | None = ...,
     ) -> None: ...
     def HasField(
-        self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"]
+        self,
+        field_name: typing_extensions.Literal[
+            "_is_barrier",
+            b"_is_barrier",
+            "func",
+            b"func",
+            "input",
+            b"input",
+            "is_barrier",
+            b"is_barrier",
+        ],
     ) -> builtins.bool: ...
     def ClearField(
-        self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"]
+        self,
+        field_name: typing_extensions.Literal[
+            "_is_barrier",
+            b"_is_barrier",
+            "func",
+            b"func",
+            "input",
+            b"input",
+            "is_barrier",
+            b"is_barrier",
+        ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_is_barrier", b"_is_barrier"]
+    ) -> typing_extensions.Literal["is_barrier"] | None: ...
 
 global___MapPartitions = MapPartitions
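
The `_is_barrier` entries above come from the synthetic oneof that protobuf generates for a proto3 `optional` field; `WhichOneof` offers a second way to test presence and distinguishes an explicit `False` from an unset field. A minimal sketch, again assuming the regenerated stubs are importable:

```
from pyspark.sql.connect.proto import relations_pb2

mp = relations_pb2.MapPartitions()
print(mp.WhichOneof("_is_barrier"))  # None: the field was never set
mp.is_barrier = False
print(mp.WhichOneof("_is_barrier"))  # 'is_barrier': set explicitly, even to False
```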
 

