You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/03/27 09:50:48 UTC
[spark] branch master updated: [SPARK-42929] make mapInPandas / mapInArrow support "is_barrier"
This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 2a1ac07132b [SPARK-42929] make mapInPandas / mapInArrow support "is_barrier"
2a1ac07132b is described below
commit 2a1ac07132b7abc13e56b9a632b3dece7e4b60ea
Author: Weichen Xu <we...@databricks.com>
AuthorDate: Mon Mar 27 17:50:23 2023 +0800
[SPARK-42929] make mapInPandas / mapInArrow support "is_barrier"
### What changes were proposed in this pull request?
make mapInPandas / mapInArrow support "is_barrier"
### Why are the changes needed?
feature parity.
### Does this PR introduce _any_ user-facing change?
Yes.
### How was this patch tested?
Manually:
`bin/pyspark --remote local`:
```
from pyspark.sql.functions import pandas_udf
df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
def filter_func(iterator):
for pdf in iterator:
yield pdf[pdf.id == 1]
df.mapInPandas(filter_func, df.schema, is_barrier=True).collect()
import pyarrow
def filter_func(iterator):
for batch in iterator:
pdf = batch.to_pandas()
yield pyarrow.RecordBatch.from_pandas(pdf[pdf.id == 1])
df.mapInArrow(filter_func, df.schema, is_barrier=True).collect()
```
Closes #40559 from WeichenXu123/spark-connect-barrier-mode.
Authored-by: Weichen Xu <we...@databricks.com>
Signed-off-by: Weichen Xu <we...@databricks.com>
---
.../main/protobuf/spark/connect/relations.proto | 3 +++
.../sql/connect/planner/SparkConnectPlanner.scala | 5 ++--
python/pyspark/sql/connect/dataframe.py | 21 +++++++++++----
python/pyspark/sql/connect/plan.py | 8 +++++-
python/pyspark/sql/connect/proto/relations_pb2.py | 24 ++++++++---------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 31 ++++++++++++++++++++--
6 files changed, 70 insertions(+), 22 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 976bd68e7fe..c965a6c8d32 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -794,6 +794,9 @@ message MapPartitions {
// (Required) Input user-defined function.
CommonInlineUserDefinedFunction func = 2;
+
+ // (Optional) isBarrier.
+ optional bool is_barrier = 3;
}
message GroupMap {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e7911ccdf11..e7e88cab643 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -484,19 +484,20 @@ class SparkConnectPlanner(val session: SparkSession) {
private def transformMapPartitions(rel: proto.MapPartitions): LogicalPlan = {
val commonUdf = rel.getFunc
val pythonUdf = transformPythonUDF(commonUdf)
+ val isBarrier = if (rel.hasIsBarrier) rel.getIsBarrier else false
pythonUdf.evalType match {
case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF =>
logical.MapInPandas(
pythonUdf,
pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
transformRelation(rel.getInput),
- false)
+ isBarrier)
case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
logical.PythonMapInArrow(
pythonUdf,
pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
transformRelation(rel.getInput),
- false)
+ isBarrier)
case _ =>
throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported")
}
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 2dfc8e72193..10426c3c28d 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1623,6 +1623,7 @@ class DataFrame:
func: "PandasMapIterFunction",
schema: Union[StructType, str],
evalType: int,
+ is_barrier: bool,
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
@@ -1636,21 +1637,31 @@ class DataFrame:
)
return DataFrame.withPlan(
- plan.MapPartitions(child=self._plan, function=udf_obj, cols=self.columns),
+ plan.MapPartitions(
+ child=self._plan, function=udf_obj, cols=self.columns, is_barrier=is_barrier
+ ),
session=self._session,
)
def mapInPandas(
- self, func: "PandasMapIterFunction", schema: Union[StructType, str]
+ self,
+ func: "PandasMapIterFunction",
+ schema: Union[StructType, str],
+ is_barrier: bool = False,
) -> "DataFrame":
- return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
+ return self._map_partitions(
+ func, schema, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, is_barrier
+ )
mapInPandas.__doc__ = PySparkDataFrame.mapInPandas.__doc__
def mapInArrow(
- self, func: "ArrowMapIterFunction", schema: Union[StructType, str]
+ self,
+ func: "ArrowMapIterFunction",
+ schema: Union[StructType, str],
+ is_barrier: bool = False,
) -> "DataFrame":
- return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_ARROW_ITER_UDF)
+ return self._map_partitions(func, schema, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, is_barrier)
mapInArrow.__doc__ = PySparkDataFrame.mapInArrow.__doc__
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 7988cc33009..12a5879db0f 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1912,17 +1912,23 @@ class MapPartitions(LogicalPlan):
"""Logical plan object for a mapPartitions-equivalent API: mapInPandas, mapInArrow."""
def __init__(
- self, child: Optional["LogicalPlan"], function: "UserDefinedFunction", cols: List[str]
+ self,
+ child: Optional["LogicalPlan"],
+ function: "UserDefinedFunction",
+ cols: List[str],
+ is_barrier: bool,
) -> None:
super().__init__(child)
self._func = function._build_common_inline_user_defined_function(*cols)
+ self._is_barrier = is_barrier
def plan(self, session: "SparkConnectClient") -> proto.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.map_partitions.input.CopyFrom(self._child.plan(session))
plan.map_partitions.func.CopyFrom(self._func.to_plan_udf(session))
+ plan.map_partitions.is_barrier = self._is_barrier
return plan
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 81a0499666b..80e66fd4aae 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xaf\x14\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xaf\x14\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -828,17 +828,17 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_REPARTITIONBYEXPRESSION._serialized_start = 9782
_REPARTITIONBYEXPRESSION._serialized_end = 9985
_MAPPARTITIONS._serialized_start = 9988
- _MAPPARTITIONS._serialized_end = 10118
- _GROUPMAP._serialized_start = 10121
- _GROUPMAP._serialized_end = 10324
- _COGROUPMAP._serialized_start = 10327
- _COGROUPMAP._serialized_end = 10679
- _COLLECTMETRICS._serialized_start = 10682
- _COLLECTMETRICS._serialized_end = 10818
- _PARSE._serialized_start = 10821
- _PARSE._serialized_end = 11209
+ _MAPPARTITIONS._serialized_end = 10169
+ _GROUPMAP._serialized_start = 10172
+ _GROUPMAP._serialized_end = 10375
+ _COGROUPMAP._serialized_start = 10378
+ _COGROUPMAP._serialized_end = 10730
+ _COLLECTMETRICS._serialized_start = 10733
+ _COLLECTMETRICS._serialized_end = 10869
+ _PARSE._serialized_start = 10872
+ _PARSE._serialized_end = 11260
_PARSE_OPTIONSENTRY._serialized_start = 3293
_PARSE_OPTIONSENTRY._serialized_end = 3351
- _PARSE_PARSEFORMAT._serialized_start = 11110
- _PARSE_PARSEFORMAT._serialized_end = 11198
+ _PARSE_PARSEFORMAT._serialized_start = 11161
+ _PARSE_PARSEFORMAT._serialized_end = 11249
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index f287a740346..c3cf733a995 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -2759,25 +2759,52 @@ class MapPartitions(google.protobuf.message.Message):
INPUT_FIELD_NUMBER: builtins.int
FUNC_FIELD_NUMBER: builtins.int
+ IS_BARRIER_FIELD_NUMBER: builtins.int
@property
def input(self) -> global___Relation:
"""(Required) Input relation for a mapPartitions-equivalent API: mapInPandas, mapInArrow."""
@property
def func(self) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction:
"""(Required) Input user-defined function."""
+ is_barrier: builtins.bool
+ """(Optional) isBarrier."""
def __init__(
self,
*,
input: global___Relation | None = ...,
func: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
| None = ...,
+ is_barrier: builtins.bool | None = ...,
) -> None: ...
def HasField(
- self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_is_barrier",
+ b"_is_barrier",
+ "func",
+ b"func",
+ "input",
+ b"input",
+ "is_barrier",
+ b"is_barrier",
+ ],
) -> builtins.bool: ...
def ClearField(
- self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_is_barrier",
+ b"_is_barrier",
+ "func",
+ b"func",
+ "input",
+ b"input",
+ "is_barrier",
+ b"is_barrier",
+ ],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_is_barrier", b"_is_barrier"]
+ ) -> typing_extensions.Literal["is_barrier"] | None: ...
global___MapPartitions = MapPartitions
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org