Posted to commits@spark.apache.org by gu...@apache.org on 2022/11/11 03:04:46 UTC

[spark] branch master updated: [SPARK-41105][CONNECT] Adopt `optional` keyword from proto3 which offers `hasXXX` to differentiate if a field is set or unset

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 46fab54b500 [SPARK-41105][CONNECT] Adopt `optional` keyword from proto3 which offers `hasXXX` to differentiate if a field is set or unset
46fab54b500 is described below

commit 46fab54b500c579cd421fb9e8ea95fae0ddda87d
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Fri Nov 11 12:04:34 2022 +0900

    [SPARK-41105][CONNECT] Adopt `optional` keyword from proto3 which offers `hasXXX` to differentiate if a field is set or unset
    
    ### What changes were proposed in this pull request?
    
    We used to wrap these fields in messages to be able to tell whether a field is set or unset. It turns out proto3 offers a built-in mechanism to achieve the same thing: https://developers.google.com/protocol-buffers/docs/proto3#specifying_field_rules.
    
    It is as easy as adding the `optional` keyword to a field, which auto-generates a `hasXXX` method.
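
    For example, a minimal sketch (assuming the generated Java/Scala classes for the `Sample` message in `relations.proto` are on the classpath) of what the `optional` keyword gives callers:

    ```scala
    // Sketch only: with `optional int64 seed = 5;`, protoc generates setSeed/hasSeed/getSeed.
    val withSeed = proto.Sample.newBuilder().setSeed(42L).build()
    val withoutSeed = proto.Sample.newBuilder().build()

    assert(withSeed.hasSeed && withSeed.getSeed == 42L)
    assert(!withoutSeed.hasSeed) // unset, so callers can fall back to a default
    ```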
    
    This PR refactors existing proto to get rid of redundant message definitions.
    
    ### Why are the changes needed?
    
    Codebase simplification.
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO
    
    ### How was this patch tested?
    
    Existing UT
    
    Closes #38606 from amaliujia/refactor_proto.
    
    Authored-by: Rui Wang <ru...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    | 12 +---
 .../org/apache/spark/sql/connect/dsl/package.scala |  5 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  4 +-
 python/pyspark/sql/connect/plan.py                 |  6 +-
 python/pyspark/sql/connect/proto/relations_pb2.py  | 40 ++++++------
 python/pyspark/sql/connect/proto/relations_pb2.pyi | 71 ++++++++++------------
 .../sql/tests/connect/test_connect_plan_only.py    |  4 +-
 7 files changed, 61 insertions(+), 81 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 639d1bafce5..4f30b5bfbde 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -215,11 +215,7 @@ message Sample {
   double lower_bound = 2;
   double upper_bound = 3;
   bool with_replacement = 4;
-  Seed seed = 5;
-
-  message Seed {
-    int64 seed = 1;
-  }
+  optional int64 seed = 5;
 }
 
 // Relation of type [[Range]] that generates a sequence of integers.
@@ -232,11 +228,7 @@ message Range {
   int64 step = 3;
   // Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
   // it is set, or 2) spark default parallelism.
-  NumPartitions num_partitions = 4;
-
-  message NumPartitions {
-    int32 num_partitions = 1;
-  }
+  optional int32 num_partitions = 4;
 }
 
 // Relation alias.
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 5e7a94da347..f55ed835d23 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -216,8 +216,7 @@ package object dsl {
           range.setStep(1L)
         }
         if (numPartitions.isDefined) {
-          range.setNumPartitions(
-            proto.Range.NumPartitions.newBuilder().setNumPartitions(numPartitions.get))
+          range.setNumPartitions(numPartitions.get)
         }
         Relation.newBuilder().setRange(range).build()
       }
@@ -376,7 +375,7 @@ package object dsl {
               .setUpperBound(upperBound)
               .setLowerBound(lowerBound)
               .setWithReplacement(withReplacement)
-              .setSeed(Sample.Seed.newBuilder().setSeed(seed).build())
+              .setSeed(seed)
               .build())
           .build()
       }
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 04ce880a925..b91fef58a11 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -104,7 +104,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       rel.getLowerBound,
       rel.getUpperBound,
       rel.getWithReplacement,
-      if (rel.hasSeed) rel.getSeed.getSeed else Utils.random.nextLong,
+      if (rel.hasSeed) rel.getSeed else Utils.random.nextLong,
       transformRelation(rel.getInput))
   }
 
@@ -117,7 +117,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     val end = rel.getEnd
     val step = rel.getStep
     val numPartitions = if (rel.hasNumPartitions) {
-      rel.getNumPartitions.getNumPartitions
+      rel.getNumPartitions
     } else {
       session.leafNodeDefaultParallelism
     }
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index e5eed195568..be1060a9fd8 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -443,7 +443,7 @@ class Sample(LogicalPlan):
         plan.sample.upper_bound = self.upper_bound
         plan.sample.with_replacement = self.with_replacement
         if self.seed is not None:
-            plan.sample.seed.seed = self.seed
+            plan.sample.seed = self.seed
         return plan
 
     def print(self, indent: int = 0) -> str:
@@ -777,9 +777,7 @@ class Range(LogicalPlan):
         rel.range.end = self._end
         rel.range.step = self._step
         if self._num_partitions is not None:
-            num_partitions_proto = rel.range.NumPartitions()
-            num_partitions_proto.num_partitions = self._num_partitions
-            rel.range.num_partitions.CopyFrom(num_partitions_proto)
+            rel.range.num_partitions = self._num_partitions
         return rel
 
     def print(self, indent: int = 0) -> str:
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 323eb8e7690..73b789cf7d6 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xb6\n\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\ [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xb6\n\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\ [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -92,25 +92,21 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _LOCALRELATION._serialized_start = 4025
     _LOCALRELATION._serialized_end = 4118
     _SAMPLE._serialized_start = 4121
-    _SAMPLE._serialized_end = 4361
-    _SAMPLE_SEED._serialized_start = 4335
-    _SAMPLE_SEED._serialized_end = 4361
-    _RANGE._serialized_start = 4364
-    _RANGE._serialized_end = 4562
-    _RANGE_NUMPARTITIONS._serialized_start = 4508
-    _RANGE_NUMPARTITIONS._serialized_end = 4562
-    _SUBQUERYALIAS._serialized_start = 4564
-    _SUBQUERYALIAS._serialized_end = 4678
-    _REPARTITION._serialized_start = 4680
-    _REPARTITION._serialized_end = 4805
-    _STATSUMMARY._serialized_start = 4807
-    _STATSUMMARY._serialized_end = 4899
-    _STATCROSSTAB._serialized_start = 4901
-    _STATCROSSTAB._serialized_end = 5002
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5004
-    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5118
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5121
-    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5380
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5313
-    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5380
+    _SAMPLE._serialized_end = 4319
+    _RANGE._serialized_start = 4322
+    _RANGE._serialized_end = 4452
+    _SUBQUERYALIAS._serialized_start = 4454
+    _SUBQUERYALIAS._serialized_end = 4568
+    _REPARTITION._serialized_start = 4570
+    _REPARTITION._serialized_end = 4695
+    _STATSUMMARY._serialized_start = 4697
+    _STATSUMMARY._serialized_end = 4789
+    _STATCROSSTAB._serialized_start = 4791
+    _STATCROSSTAB._serialized_end = 4892
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 4894
+    _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5008
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5011
+    _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5270
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5203
+    _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5270
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 53f75b7520f..e706fa3e11d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -924,18 +924,6 @@ class Sample(google.protobuf.message.Message):
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-    class Seed(google.protobuf.message.Message):
-        DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-        SEED_FIELD_NUMBER: builtins.int
-        seed: builtins.int
-        def __init__(
-            self,
-            *,
-            seed: builtins.int = ...,
-        ) -> None: ...
-        def ClearField(self, field_name: typing_extensions.Literal["seed", b"seed"]) -> None: ...
-
     INPUT_FIELD_NUMBER: builtins.int
     LOWER_BOUND_FIELD_NUMBER: builtins.int
     UPPER_BOUND_FIELD_NUMBER: builtins.int
@@ -946,8 +934,7 @@ class Sample(google.protobuf.message.Message):
     lower_bound: builtins.float
     upper_bound: builtins.float
     with_replacement: builtins.bool
-    @property
-    def seed(self) -> global___Sample.Seed: ...
+    seed: builtins.int
     def __init__(
         self,
         *,
@@ -955,14 +942,19 @@ class Sample(google.protobuf.message.Message):
         lower_bound: builtins.float = ...,
         upper_bound: builtins.float = ...,
         with_replacement: builtins.bool = ...,
-        seed: global___Sample.Seed | None = ...,
+        seed: builtins.int | None = ...,
     ) -> None: ...
     def HasField(
-        self, field_name: typing_extensions.Literal["input", b"input", "seed", b"seed"]
+        self,
+        field_name: typing_extensions.Literal[
+            "_seed", b"_seed", "input", b"input", "seed", b"seed"
+        ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "_seed",
+            b"_seed",
             "input",
             b"input",
             "lower_bound",
@@ -975,6 +967,9 @@ class Sample(google.protobuf.message.Message):
             b"with_replacement",
         ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_seed", b"_seed"]
+    ) -> typing_extensions.Literal["seed"] | None: ...
 
 global___Sample = Sample
 
@@ -983,20 +978,6 @@ class Range(google.protobuf.message.Message):
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
-    class NumPartitions(google.protobuf.message.Message):
-        DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-        NUM_PARTITIONS_FIELD_NUMBER: builtins.int
-        num_partitions: builtins.int
-        def __init__(
-            self,
-            *,
-            num_partitions: builtins.int = ...,
-        ) -> None: ...
-        def ClearField(
-            self, field_name: typing_extensions.Literal["num_partitions", b"num_partitions"]
-        ) -> None: ...
-
     START_FIELD_NUMBER: builtins.int
     END_FIELD_NUMBER: builtins.int
     STEP_FIELD_NUMBER: builtins.int
@@ -1007,28 +988,42 @@ class Range(google.protobuf.message.Message):
     """Required."""
     step: builtins.int
     """Required."""
-    @property
-    def num_partitions(self) -> global___Range.NumPartitions:
-        """Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
-        it is set, or 2) spark default parallelism.
-        """
+    num_partitions: builtins.int
+    """Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
+    it is set, or 2) spark default parallelism.
+    """
     def __init__(
         self,
         *,
         start: builtins.int = ...,
         end: builtins.int = ...,
         step: builtins.int = ...,
-        num_partitions: global___Range.NumPartitions | None = ...,
+        num_partitions: builtins.int | None = ...,
     ) -> None: ...
     def HasField(
-        self, field_name: typing_extensions.Literal["num_partitions", b"num_partitions"]
+        self,
+        field_name: typing_extensions.Literal[
+            "_num_partitions", b"_num_partitions", "num_partitions", b"num_partitions"
+        ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "end", b"end", "num_partitions", b"num_partitions", "start", b"start", "step", b"step"
+            "_num_partitions",
+            b"_num_partitions",
+            "end",
+            b"end",
+            "num_partitions",
+            b"num_partitions",
+            "start",
+            b"start",
+            "step",
+            b"step",
         ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_num_partitions", b"_num_partitions"]
+    ) -> typing_extensions.Literal["num_partitions"] | None: ...
 
 global___Range = Range
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index c46d4d10624..4e26581a002 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -121,7 +121,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.sample.lower_bound, 0.0)
         self.assertEqual(plan.root.sample.upper_bound, 0.4)
         self.assertEqual(plan.root.sample.with_replacement, True)
-        self.assertEqual(plan.root.sample.seed.seed, -1)
+        self.assertEqual(plan.root.sample.seed, -1)
 
     def test_sort(self):
         df = self.connect.readTable(table_name=self.tbl_name)
@@ -180,7 +180,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
         self.assertEqual(plan.root.range.start, 10)
         self.assertEqual(plan.root.range.end, 20)
         self.assertEqual(plan.root.range.step, 3)
-        self.assertEqual(plan.root.range.num_partitions.num_partitions, 4)
+        self.assertEqual(plan.root.range.num_partitions, 4)
 
         plan = self.connect.range(start=10, end=20)._plan.to_proto(self.connect)
         self.assertEqual(plan.root.range.start, 10)

