You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/11/08 01:00:39 UTC

[spark] branch master updated: [SPARK-41026][CONNECT] Support Repartition in Connect Proto

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9695b2cb59b [SPARK-41026][CONNECT] Support Repartition in Connect Proto
9695b2cb59b is described below

commit 9695b2cb59b497709ca0050d754491d935742530
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Tue Nov 8 09:00:12 2022 +0800

    [SPARK-41026][CONNECT] Support Repartition in Connect Proto
    
    ### What changes were proposed in this pull request?
    
    Support `Repartition` in Connect proto, which further supports two APIs: `repartition` (shuffle=true) and `coalesce` (shuffle=false).
    
    ### Why are the changes needed?
    
    Improve API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO
    
    ### How was this patch tested?
    
    UT
    
    Closes #38529 from amaliujia/support_repartition_in_proto_connect.
    
    Authored-by: Rui Wang <ru...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../main/protobuf/spark/connect/relations.proto    |  13 +++
 .../org/apache/spark/sql/connect/dsl/package.scala |  18 ++++
 .../sql/connect/planner/SparkConnectPlanner.scala  |   5 +
 .../connect/planner/SparkConnectProtoSuite.scala   |  10 ++
 python/pyspark/sql/connect/proto/relations_pb2.py  | 114 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  43 ++++++++
 6 files changed, 147 insertions(+), 56 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 8edd8911242..36113e2a30c 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -46,6 +46,7 @@ message Relation {
     Deduplicate deduplicate = 14;
     Range range = 15;
     SubqueryAlias subquery_alias = 16;
+    Repartition repartition = 17;
 
     Unknown unknown = 999;
   }
@@ -241,3 +242,15 @@ message SubqueryAlias {
   // Optional. Qualifier of the alias.
   repeated string qualifier = 3;
 }
+
+// Relation repartition.
+message Repartition {
+  // Required. The input relation.
+  Relation input = 1;
+
+  // Required. Must be positive.
+  int32 num_partitions = 2;
+
+  // Optional. Default value is false.
+  bool shuffle = 3;
+}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index c40a9eed753..2755727de11 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -423,6 +423,24 @@ package object dsl {
               byName))
           .build()
 
+      def coalesce(num: Integer): Relation =
+        Relation
+          .newBuilder()
+          .setRepartition(
+            Repartition
+              .newBuilder()
+              .setInput(logicalPlan)
+              .setNumPartitions(num)
+              .setShuffle(false))
+          .build()
+
+      def repartition(num: Integer): Relation =
+        Relation
+          .newBuilder()
+          .setRepartition(
+            Repartition.newBuilder().setInput(logicalPlan).setNumPartitions(num).setShuffle(true))
+          .build()
+
       private def createSetOperation(
           left: Relation,
           right: Relation,
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d2b474711ab..1615fc56ab6 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -72,6 +72,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange)
       case proto.Relation.RelTypeCase.SUBQUERY_ALIAS =>
         transformSubqueryAlias(rel.getSubqueryAlias)
+      case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
       case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
         throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
       case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -107,6 +108,10 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       transformRelation(rel.getInput))
   }
 
+  private def transformRepartition(rel: proto.Repartition): LogicalPlan = {
+    logical.Repartition(rel.getNumPartitions, rel.getShuffle, transformRelation(rel.getInput))
+  }
+
   private def transformRange(rel: proto.Range): LogicalPlan = {
     val start = rel.getStart
     val end = rel.getEnd
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 0aa89d6f640..72dae674721 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -251,6 +251,16 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     comparePlans(connect.sql("SELECT 1"), spark.sql("SELECT 1"))
   }
 
+  test("Test Repartition") {
+    val connectPlan1 = connectTestRelation.repartition(12)
+    val sparkPlan1 = sparkTestRelation.repartition(12)
+    comparePlans(connectPlan1, sparkPlan1)
+
+    val connectPlan2 = connectTestRelation.coalesce(2)
+    val sparkPlan2 = sparkTestRelation.coalesce(2)
+    comparePlans(connectPlan2, sparkPlan2)
+  }
+
   private def createLocalRelationProtoByQualifiedAttributes(
       attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 6180c5e13c9..e43a5de583e 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x8c\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xcc\x07\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -44,59 +44,61 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _READ_DATASOURCE_OPTIONSENTRY._options = None
     _READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 82
-    _RELATION._serialized_end = 990
-    _UNKNOWN._serialized_start = 992
-    _UNKNOWN._serialized_end = 1001
-    _RELATIONCOMMON._serialized_start = 1003
-    _RELATIONCOMMON._serialized_end = 1052
-    _SQL._serialized_start = 1054
-    _SQL._serialized_end = 1081
-    _READ._serialized_start = 1084
-    _READ._serialized_end = 1494
-    _READ_NAMEDTABLE._serialized_start = 1226
-    _READ_NAMEDTABLE._serialized_end = 1287
-    _READ_DATASOURCE._serialized_start = 1290
-    _READ_DATASOURCE._serialized_end = 1481
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1423
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1481
-    _PROJECT._serialized_start = 1496
-    _PROJECT._serialized_end = 1613
-    _FILTER._serialized_start = 1615
-    _FILTER._serialized_end = 1727
-    _JOIN._serialized_start = 1730
-    _JOIN._serialized_end = 2180
-    _JOIN_JOINTYPE._serialized_start = 1993
-    _JOIN_JOINTYPE._serialized_end = 2180
-    _SETOPERATION._serialized_start = 2183
-    _SETOPERATION._serialized_end = 2546
-    _SETOPERATION_SETOPTYPE._serialized_start = 2432
-    _SETOPERATION_SETOPTYPE._serialized_end = 2546
-    _LIMIT._serialized_start = 2548
-    _LIMIT._serialized_end = 2624
-    _OFFSET._serialized_start = 2626
-    _OFFSET._serialized_end = 2705
-    _AGGREGATE._serialized_start = 2708
-    _AGGREGATE._serialized_end = 2918
-    _SORT._serialized_start = 2921
-    _SORT._serialized_end = 3452
-    _SORT_SORTFIELD._serialized_start = 3070
-    _SORT_SORTFIELD._serialized_end = 3258
-    _SORT_SORTDIRECTION._serialized_start = 3260
-    _SORT_SORTDIRECTION._serialized_end = 3368
-    _SORT_SORTNULLS._serialized_start = 3370
-    _SORT_SORTNULLS._serialized_end = 3452
-    _DEDUPLICATE._serialized_start = 3455
-    _DEDUPLICATE._serialized_end = 3597
-    _LOCALRELATION._serialized_start = 3599
-    _LOCALRELATION._serialized_end = 3692
-    _SAMPLE._serialized_start = 3695
-    _SAMPLE._serialized_end = 3935
-    _SAMPLE_SEED._serialized_start = 3909
-    _SAMPLE_SEED._serialized_end = 3935
-    _RANGE._serialized_start = 3938
-    _RANGE._serialized_end = 4136
-    _RANGE_NUMPARTITIONS._serialized_start = 4082
-    _RANGE_NUMPARTITIONS._serialized_end = 4136
-    _SUBQUERYALIAS._serialized_start = 4138
-    _SUBQUERYALIAS._serialized_end = 4252
+    _RELATION._serialized_end = 1054
+    _UNKNOWN._serialized_start = 1056
+    _UNKNOWN._serialized_end = 1065
+    _RELATIONCOMMON._serialized_start = 1067
+    _RELATIONCOMMON._serialized_end = 1116
+    _SQL._serialized_start = 1118
+    _SQL._serialized_end = 1145
+    _READ._serialized_start = 1148
+    _READ._serialized_end = 1558
+    _READ_NAMEDTABLE._serialized_start = 1290
+    _READ_NAMEDTABLE._serialized_end = 1351
+    _READ_DATASOURCE._serialized_start = 1354
+    _READ_DATASOURCE._serialized_end = 1545
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1487
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1545
+    _PROJECT._serialized_start = 1560
+    _PROJECT._serialized_end = 1677
+    _FILTER._serialized_start = 1679
+    _FILTER._serialized_end = 1791
+    _JOIN._serialized_start = 1794
+    _JOIN._serialized_end = 2244
+    _JOIN_JOINTYPE._serialized_start = 2057
+    _JOIN_JOINTYPE._serialized_end = 2244
+    _SETOPERATION._serialized_start = 2247
+    _SETOPERATION._serialized_end = 2610
+    _SETOPERATION_SETOPTYPE._serialized_start = 2496
+    _SETOPERATION_SETOPTYPE._serialized_end = 2610
+    _LIMIT._serialized_start = 2612
+    _LIMIT._serialized_end = 2688
+    _OFFSET._serialized_start = 2690
+    _OFFSET._serialized_end = 2769
+    _AGGREGATE._serialized_start = 2772
+    _AGGREGATE._serialized_end = 2982
+    _SORT._serialized_start = 2985
+    _SORT._serialized_end = 3516
+    _SORT_SORTFIELD._serialized_start = 3134
+    _SORT_SORTFIELD._serialized_end = 3322
+    _SORT_SORTDIRECTION._serialized_start = 3324
+    _SORT_SORTDIRECTION._serialized_end = 3432
+    _SORT_SORTNULLS._serialized_start = 3434
+    _SORT_SORTNULLS._serialized_end = 3516
+    _DEDUPLICATE._serialized_start = 3519
+    _DEDUPLICATE._serialized_end = 3661
+    _LOCALRELATION._serialized_start = 3663
+    _LOCALRELATION._serialized_end = 3756
+    _SAMPLE._serialized_start = 3759
+    _SAMPLE._serialized_end = 3999
+    _SAMPLE_SEED._serialized_start = 3973
+    _SAMPLE_SEED._serialized_end = 3999
+    _RANGE._serialized_start = 4002
+    _RANGE._serialized_end = 4200
+    _RANGE_NUMPARTITIONS._serialized_start = 4146
+    _RANGE_NUMPARTITIONS._serialized_end = 4200
+    _SUBQUERYALIAS._serialized_start = 4202
+    _SUBQUERYALIAS._serialized_end = 4316
+    _REPARTITION._serialized_start = 4318
+    _REPARTITION._serialized_end = 4443
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index f5b5c9f90dc..30c1dddf885 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -75,6 +75,7 @@ class Relation(google.protobuf.message.Message):
     DEDUPLICATE_FIELD_NUMBER: builtins.int
     RANGE_FIELD_NUMBER: builtins.int
     SUBQUERY_ALIAS_FIELD_NUMBER: builtins.int
+    REPARTITION_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
     @property
     def common(self) -> global___RelationCommon: ...
@@ -109,6 +110,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def subquery_alias(self) -> global___SubqueryAlias: ...
     @property
+    def repartition(self) -> global___Repartition: ...
+    @property
     def unknown(self) -> global___Unknown: ...
     def __init__(
         self,
@@ -129,6 +132,7 @@ class Relation(google.protobuf.message.Message):
         deduplicate: global___Deduplicate | None = ...,
         range: global___Range | None = ...,
         subquery_alias: global___SubqueryAlias | None = ...,
+        repartition: global___Repartition | None = ...,
         unknown: global___Unknown | None = ...,
     ) -> None: ...
     def HasField(
@@ -158,6 +162,8 @@ class Relation(google.protobuf.message.Message):
             b"read",
             "rel_type",
             b"rel_type",
+            "repartition",
+            b"repartition",
             "sample",
             b"sample",
             "set_op",
@@ -199,6 +205,8 @@ class Relation(google.protobuf.message.Message):
             b"read",
             "rel_type",
             b"rel_type",
+            "repartition",
+            b"repartition",
             "sample",
             b"sample",
             "set_op",
@@ -231,6 +239,7 @@ class Relation(google.protobuf.message.Message):
         "deduplicate",
         "range",
         "subquery_alias",
+        "repartition",
         "unknown",
     ] | None: ...
 
@@ -1022,3 +1031,37 @@ class SubqueryAlias(google.protobuf.message.Message):
     ) -> None: ...
 
 global___SubqueryAlias = SubqueryAlias
+
+class Repartition(google.protobuf.message.Message):
+    """Relation repartition."""
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    NUM_PARTITIONS_FIELD_NUMBER: builtins.int
+    SHUFFLE_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """Required. The input relation."""
+    num_partitions: builtins.int
+    """Required. Must be positive."""
+    shuffle: builtins.bool
+    """Optional. Default value is false."""
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        num_partitions: builtins.int = ...,
+        shuffle: builtins.bool = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "input", b"input", "num_partitions", b"num_partitions", "shuffle", b"shuffle"
+        ],
+    ) -> None: ...
+
+global___Repartition = Repartition


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org