Posted to commits@spark.apache.org by we...@apache.org on 2022/11/08 01:00:39 UTC
[spark] branch master updated: [SPARK-41026][CONNECT] Support Repartition in Connect Proto
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9695b2cb59b [SPARK-41026][CONNECT] Support Repartition in Connect Proto
9695b2cb59b is described below
commit 9695b2cb59b497709ca0050d754491d935742530
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Tue Nov 8 09:00:12 2022 +0800
[SPARK-41026][CONNECT] Support Repartition in Connect Proto
### What changes were proposed in this pull request?
Support `Repartition` in the Connect proto, which in turn supports two APIs: `repartition` (shuffle=true) and `coalesce` (shuffle=false).
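For context, a minimal sketch of the two `Dataset` operations this message carries (an illustration, not part of this change; `df` is any existing DataFrame):

    // repartition(n) shuffles rows into exactly n partitions;
    // coalesce(n) only merges existing partitions, avoiding a full shuffle.
    val shuffled = df.repartition(12) // carried as Repartition(shuffle = true)
    val merged   = df.coalesce(2)     // carried as Repartition(shuffle = false)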
### Why are the changes needed?
Improve API coverage.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
UT (a new unit test in `SparkConnectProtoSuite`).
Closes #38529 from amaliujia/support_repartition_in_proto_connect.
Authored-by: Rui Wang <ru...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../main/protobuf/spark/connect/relations.proto | 13 +++
.../org/apache/spark/sql/connect/dsl/package.scala | 18 ++++
.../sql/connect/planner/SparkConnectPlanner.scala | 5 +
.../connect/planner/SparkConnectProtoSuite.scala | 10 ++
python/pyspark/sql/connect/proto/relations_pb2.py | 114 +++++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 43 ++++++++
6 files changed, 147 insertions(+), 56 deletions(-)
diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 8edd8911242..36113e2a30c 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -46,6 +46,7 @@ message Relation {
Deduplicate deduplicate = 14;
Range range = 15;
SubqueryAlias subquery_alias = 16;
+ Repartition repartition = 17;
Unknown unknown = 999;
}
@@ -241,3 +242,15 @@ message SubqueryAlias {
// Optional. Qualifier of the alias.
repeated string qualifier = 3;
}
+
+// Relation repartition.
+message Repartition {
+ // Required. The input relation.
+ Relation input = 1;
+
+ // Required. Must be positive.
+ int32 num_partitions = 2;
+
+ // Optional. Default value is false.
+ bool shuffle = 3;
+}
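To make the new message concrete, a minimal sketch of building it through the generated builders (mirroring the DSL helpers added below; `input` stands in for any existing Relation):

    // Hand-built Repartition relation: shuffle = true backs repartition,
    // shuffle = false backs coalesce.
    val rel = Relation
      .newBuilder()
      .setRepartition(
        Repartition
          .newBuilder()
          .setInput(input)
          .setNumPartitions(4)
          .setShuffle(true))
      .build()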
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index c40a9eed753..2755727de11 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -423,6 +423,24 @@ package object dsl {
byName))
.build()
+ def coalesce(num: Integer): Relation =
+ Relation
+ .newBuilder()
+ .setRepartition(
+ Repartition
+ .newBuilder()
+ .setInput(logicalPlan)
+ .setNumPartitions(num)
+ .setShuffle(false))
+ .build()
+
+ def repartition(num: Integer): Relation =
+ Relation
+ .newBuilder()
+ .setRepartition(
+ Repartition.newBuilder().setInput(logicalPlan).setNumPartitions(num).setShuffle(true))
+ .build()
+
private def createSetOperation(
left: Relation,
right: Relation,
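The two DSL helpers differ only in the `shuffle` flag they set on the same `Repartition` message, mirroring the `repartition`/`coalesce` split described above. A usage sketch (assuming `rel` is a Relation with the DSL's implicit ops in scope, as in the test suite below):

    // Both calls yield a Relation with the repartition field set.
    val shuffled = rel.repartition(12) // shuffle = true
    val merged   = rel.coalesce(2)     // shuffle = false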
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d2b474711ab..1615fc56ab6 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -72,6 +72,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange)
case proto.Relation.RelTypeCase.SUBQUERY_ALIAS =>
transformSubqueryAlias(rel.getSubqueryAlias)
+ case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -107,6 +108,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
transformRelation(rel.getInput))
}
+ private def transformRepartition(rel: proto.Repartition): LogicalPlan = {
+ logical.Repartition(rel.getNumPartitions, rel.getShuffle, transformRelation(rel.getInput))
+ }
+
private def transformRange(rel: proto.Range): LogicalPlan = {
val start = rel.getStart
val end = rel.getEnd
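The planner mapping is one-to-one: the message's three fields feed Catalyst's logical `Repartition` node directly. A sketch of what `transformRepartition` yields (assuming a resolved child plan `child`):

    // For Repartition(input = child, num_partitions = 12, shuffle = true):
    val plan = logical.Repartition(numPartitions = 12, shuffle = true, child = child)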
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 0aa89d6f640..72dae674721 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -251,6 +251,16 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
comparePlans(connect.sql("SELECT 1"), spark.sql("SELECT 1"))
}
+ test("Test Repartition") {
+ val connectPlan1 = connectTestRelation.repartition(12)
+ val sparkPlan1 = sparkTestRelation.repartition(12)
+ comparePlans(connectPlan1, sparkPlan1)
+
+ val connectPlan2 = connectTestRelation.coalesce(2)
+ val sparkPlan2 = sparkTestRelation.coalesce(2)
+ comparePlans(connectPlan2, sparkPlan2)
+ }
+
private def createLocalRelationProtoByQualifiedAttributes(
attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
val localRelationBuilder = proto.LocalRelation.newBuilder()
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 6180c5e13c9..e43a5de583e 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xcc\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -44,59 +44,61 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_READ_DATASOURCE_OPTIONSENTRY._options = None
_READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001"
_RELATION._serialized_start = 82
- _RELATION._serialized_end = 990
- _UNKNOWN._serialized_start = 992
- _UNKNOWN._serialized_end = 1001
- _RELATIONCOMMON._serialized_start = 1003
- _RELATIONCOMMON._serialized_end = 1052
- _SQL._serialized_start = 1054
- _SQL._serialized_end = 1081
- _READ._serialized_start = 1084
- _READ._serialized_end = 1494
- _READ_NAMEDTABLE._serialized_start = 1226
- _READ_NAMEDTABLE._serialized_end = 1287
- _READ_DATASOURCE._serialized_start = 1290
- _READ_DATASOURCE._serialized_end = 1481
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1423
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1481
- _PROJECT._serialized_start = 1496
- _PROJECT._serialized_end = 1613
- _FILTER._serialized_start = 1615
- _FILTER._serialized_end = 1727
- _JOIN._serialized_start = 1730
- _JOIN._serialized_end = 2180
- _JOIN_JOINTYPE._serialized_start = 1993
- _JOIN_JOINTYPE._serialized_end = 2180
- _SETOPERATION._serialized_start = 2183
- _SETOPERATION._serialized_end = 2546
- _SETOPERATION_SETOPTYPE._serialized_start = 2432
- _SETOPERATION_SETOPTYPE._serialized_end = 2546
- _LIMIT._serialized_start = 2548
- _LIMIT._serialized_end = 2624
- _OFFSET._serialized_start = 2626
- _OFFSET._serialized_end = 2705
- _AGGREGATE._serialized_start = 2708
- _AGGREGATE._serialized_end = 2918
- _SORT._serialized_start = 2921
- _SORT._serialized_end = 3452
- _SORT_SORTFIELD._serialized_start = 3070
- _SORT_SORTFIELD._serialized_end = 3258
- _SORT_SORTDIRECTION._serialized_start = 3260
- _SORT_SORTDIRECTION._serialized_end = 3368
- _SORT_SORTNULLS._serialized_start = 3370
- _SORT_SORTNULLS._serialized_end = 3452
- _DEDUPLICATE._serialized_start = 3455
- _DEDUPLICATE._serialized_end = 3597
- _LOCALRELATION._serialized_start = 3599
- _LOCALRELATION._serialized_end = 3692
- _SAMPLE._serialized_start = 3695
- _SAMPLE._serialized_end = 3935
- _SAMPLE_SEED._serialized_start = 3909
- _SAMPLE_SEED._serialized_end = 3935
- _RANGE._serialized_start = 3938
- _RANGE._serialized_end = 4136
- _RANGE_NUMPARTITIONS._serialized_start = 4082
- _RANGE_NUMPARTITIONS._serialized_end = 4136
- _SUBQUERYALIAS._serialized_start = 4138
- _SUBQUERYALIAS._serialized_end = 4252
+ _RELATION._serialized_end = 1054
+ _UNKNOWN._serialized_start = 1056
+ _UNKNOWN._serialized_end = 1065
+ _RELATIONCOMMON._serialized_start = 1067
+ _RELATIONCOMMON._serialized_end = 1116
+ _SQL._serialized_start = 1118
+ _SQL._serialized_end = 1145
+ _READ._serialized_start = 1148
+ _READ._serialized_end = 1558
+ _READ_NAMEDTABLE._serialized_start = 1290
+ _READ_NAMEDTABLE._serialized_end = 1351
+ _READ_DATASOURCE._serialized_start = 1354
+ _READ_DATASOURCE._serialized_end = 1545
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1487
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1545
+ _PROJECT._serialized_start = 1560
+ _PROJECT._serialized_end = 1677
+ _FILTER._serialized_start = 1679
+ _FILTER._serialized_end = 1791
+ _JOIN._serialized_start = 1794
+ _JOIN._serialized_end = 2244
+ _JOIN_JOINTYPE._serialized_start = 2057
+ _JOIN_JOINTYPE._serialized_end = 2244
+ _SETOPERATION._serialized_start = 2247
+ _SETOPERATION._serialized_end = 2610
+ _SETOPERATION_SETOPTYPE._serialized_start = 2496
+ _SETOPERATION_SETOPTYPE._serialized_end = 2610
+ _LIMIT._serialized_start = 2612
+ _LIMIT._serialized_end = 2688
+ _OFFSET._serialized_start = 2690
+ _OFFSET._serialized_end = 2769
+ _AGGREGATE._serialized_start = 2772
+ _AGGREGATE._serialized_end = 2982
+ _SORT._serialized_start = 2985
+ _SORT._serialized_end = 3516
+ _SORT_SORTFIELD._serialized_start = 3134
+ _SORT_SORTFIELD._serialized_end = 3322
+ _SORT_SORTDIRECTION._serialized_start = 3324
+ _SORT_SORTDIRECTION._serialized_end = 3432
+ _SORT_SORTNULLS._serialized_start = 3434
+ _SORT_SORTNULLS._serialized_end = 3516
+ _DEDUPLICATE._serialized_start = 3519
+ _DEDUPLICATE._serialized_end = 3661
+ _LOCALRELATION._serialized_start = 3663
+ _LOCALRELATION._serialized_end = 3756
+ _SAMPLE._serialized_start = 3759
+ _SAMPLE._serialized_end = 3999
+ _SAMPLE_SEED._serialized_start = 3973
+ _SAMPLE_SEED._serialized_end = 3999
+ _RANGE._serialized_start = 4002
+ _RANGE._serialized_end = 4200
+ _RANGE_NUMPARTITIONS._serialized_start = 4146
+ _RANGE_NUMPARTITIONS._serialized_end = 4200
+ _SUBQUERYALIAS._serialized_start = 4202
+ _SUBQUERYALIAS._serialized_end = 4316
+ _REPARTITION._serialized_start = 4318
+ _REPARTITION._serialized_end = 4443
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index f5b5c9f90dc..30c1dddf885 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -75,6 +75,7 @@ class Relation(google.protobuf.message.Message):
DEDUPLICATE_FIELD_NUMBER: builtins.int
RANGE_FIELD_NUMBER: builtins.int
SUBQUERY_ALIAS_FIELD_NUMBER: builtins.int
+ REPARTITION_FIELD_NUMBER: builtins.int
UNKNOWN_FIELD_NUMBER: builtins.int
@property
def common(self) -> global___RelationCommon: ...
@@ -109,6 +110,8 @@ class Relation(google.protobuf.message.Message):
@property
def subquery_alias(self) -> global___SubqueryAlias: ...
@property
+ def repartition(self) -> global___Repartition: ...
+ @property
def unknown(self) -> global___Unknown: ...
def __init__(
self,
@@ -129,6 +132,7 @@ class Relation(google.protobuf.message.Message):
deduplicate: global___Deduplicate | None = ...,
range: global___Range | None = ...,
subquery_alias: global___SubqueryAlias | None = ...,
+ repartition: global___Repartition | None = ...,
unknown: global___Unknown | None = ...,
) -> None: ...
def HasField(
@@ -158,6 +162,8 @@ class Relation(google.protobuf.message.Message):
b"read",
"rel_type",
b"rel_type",
+ "repartition",
+ b"repartition",
"sample",
b"sample",
"set_op",
@@ -199,6 +205,8 @@ class Relation(google.protobuf.message.Message):
b"read",
"rel_type",
b"rel_type",
+ "repartition",
+ b"repartition",
"sample",
b"sample",
"set_op",
@@ -231,6 +239,7 @@ class Relation(google.protobuf.message.Message):
"deduplicate",
"range",
"subquery_alias",
+ "repartition",
"unknown",
] | None: ...
@@ -1022,3 +1031,37 @@ class SubqueryAlias(google.protobuf.message.Message):
) -> None: ...
global___SubqueryAlias = SubqueryAlias
+
+class Repartition(google.protobuf.message.Message):
+ """Relation repartition."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ INPUT_FIELD_NUMBER: builtins.int
+ NUM_PARTITIONS_FIELD_NUMBER: builtins.int
+ SHUFFLE_FIELD_NUMBER: builtins.int
+ @property
+ def input(self) -> global___Relation:
+ """Required. The input relation."""
+ num_partitions: builtins.int
+ """Required. Must be positive."""
+ shuffle: builtins.bool
+ """Optional. Default value is false."""
+ def __init__(
+ self,
+ *,
+ input: global___Relation | None = ...,
+ num_partitions: builtins.int = ...,
+ shuffle: builtins.bool = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["input", b"input"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "input", b"input", "num_partitions", b"num_partitions", "shuffle", b"shuffle"
+ ],
+ ) -> None: ...
+
+global___Repartition = Repartition