You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/11/11 03:04:46 UTC
[spark] branch master updated: [SPARK-41105][CONNECT] Adopt `optional` keyword from proto3 which offers `hasXXX` to differentiate if a field is set or unset
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 46fab54b500 [SPARK-41105][CONNECT] Adopt `optional` keyword from proto3 which offers `hasXXX` to differentiate if a field is set or unset
46fab54b500 is described below
commit 46fab54b500c579cd421fb9e8ea95fae0ddda87d
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Fri Nov 11 12:04:34 2022 +0900
[SPARK-41105][CONNECT] Adopt `optional` keyword from proto3 which offers `hasXXX` to differentiate if a field is set or unset
### What changes were proposed in this pull request?
We used to wrap those fields into messages to acquire the ability to tell if those fields are set or unset. It turns out proto3 offers a built-in mechanism to achieve the same thing: https://developers.google.com/protocol-buffers/docs/proto3#specifying_field_rules.
It is as easy as adding the `optional` keyword to the field to auto-generate the `hasXXX` method.
This PR refactors existing proto to get rid of redundant message definitions.
### Why are the changes needed?
Codebase simplification.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
Existing UT
Closes #38606 from amaliujia/refactor_proto.
Authored-by: Rui Wang <ru...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../main/protobuf/spark/connect/relations.proto | 12 +---
.../org/apache/spark/sql/connect/dsl/package.scala | 5 +-
.../sql/connect/planner/SparkConnectPlanner.scala | 4 +-
python/pyspark/sql/connect/plan.py | 6 +-
python/pyspark/sql/connect/proto/relations_pb2.py | 40 ++++++------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 71 ++++++++++------------
.../sql/tests/connect/test_connect_plan_only.py | 4 +-
7 files changed, 61 insertions(+), 81 deletions(-)
diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 639d1bafce5..4f30b5bfbde 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -215,11 +215,7 @@ message Sample {
double lower_bound = 2;
double upper_bound = 3;
bool with_replacement = 4;
- Seed seed = 5;
-
- message Seed {
- int64 seed = 1;
- }
+ optional int64 seed = 5;
}
// Relation of type [[Range]] that generates a sequence of integers.
@@ -232,11 +228,7 @@ message Range {
int64 step = 3;
// Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
// it is set, or 2) spark default parallelism.
- NumPartitions num_partitions = 4;
-
- message NumPartitions {
- int32 num_partitions = 1;
- }
+ optional int32 num_partitions = 4;
}
// Relation alias.
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 5e7a94da347..f55ed835d23 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -216,8 +216,7 @@ package object dsl {
range.setStep(1L)
}
if (numPartitions.isDefined) {
- range.setNumPartitions(
- proto.Range.NumPartitions.newBuilder().setNumPartitions(numPartitions.get))
+ range.setNumPartitions(numPartitions.get)
}
Relation.newBuilder().setRange(range).build()
}
@@ -376,7 +375,7 @@ package object dsl {
.setUpperBound(upperBound)
.setLowerBound(lowerBound)
.setWithReplacement(withReplacement)
- .setSeed(Sample.Seed.newBuilder().setSeed(seed).build())
+ .setSeed(seed)
.build())
.build()
}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 04ce880a925..b91fef58a11 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -104,7 +104,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
rel.getLowerBound,
rel.getUpperBound,
rel.getWithReplacement,
- if (rel.hasSeed) rel.getSeed.getSeed else Utils.random.nextLong,
+ if (rel.hasSeed) rel.getSeed else Utils.random.nextLong,
transformRelation(rel.getInput))
}
@@ -117,7 +117,7 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
val end = rel.getEnd
val step = rel.getStep
val numPartitions = if (rel.hasNumPartitions) {
- rel.getNumPartitions.getNumPartitions
+ rel.getNumPartitions
} else {
session.leafNodeDefaultParallelism
}
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index e5eed195568..be1060a9fd8 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -443,7 +443,7 @@ class Sample(LogicalPlan):
plan.sample.upper_bound = self.upper_bound
plan.sample.with_replacement = self.with_replacement
if self.seed is not None:
- plan.sample.seed.seed = self.seed
+ plan.sample.seed = self.seed
return plan
def print(self, indent: int = 0) -> str:
@@ -777,9 +777,7 @@ class Range(LogicalPlan):
rel.range.end = self._end
rel.range.step = self._step
if self._num_partitions is not None:
- num_partitions_proto = rel.range.NumPartitions()
- num_partitions_proto.num_partitions = self._num_partitions
- rel.range.num_partitions.CopyFrom(num_partitions_proto)
+ rel.range.num_partitions = self._num_partitions
return rel
def print(self, indent: int = 0) -> str:
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 323eb8e7690..73b789cf7d6 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xb6\n\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\ [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xb6\n\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\ [...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -92,25 +92,21 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_LOCALRELATION._serialized_start = 4025
_LOCALRELATION._serialized_end = 4118
_SAMPLE._serialized_start = 4121
- _SAMPLE._serialized_end = 4361
- _SAMPLE_SEED._serialized_start = 4335
- _SAMPLE_SEED._serialized_end = 4361
- _RANGE._serialized_start = 4364
- _RANGE._serialized_end = 4562
- _RANGE_NUMPARTITIONS._serialized_start = 4508
- _RANGE_NUMPARTITIONS._serialized_end = 4562
- _SUBQUERYALIAS._serialized_start = 4564
- _SUBQUERYALIAS._serialized_end = 4678
- _REPARTITION._serialized_start = 4680
- _REPARTITION._serialized_end = 4805
- _STATSUMMARY._serialized_start = 4807
- _STATSUMMARY._serialized_end = 4899
- _STATCROSSTAB._serialized_start = 4901
- _STATCROSSTAB._serialized_end = 5002
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5004
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5118
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5121
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5380
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5313
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5380
+ _SAMPLE._serialized_end = 4319
+ _RANGE._serialized_start = 4322
+ _RANGE._serialized_end = 4452
+ _SUBQUERYALIAS._serialized_start = 4454
+ _SUBQUERYALIAS._serialized_end = 4568
+ _REPARTITION._serialized_start = 4570
+ _REPARTITION._serialized_end = 4695
+ _STATSUMMARY._serialized_start = 4697
+ _STATSUMMARY._serialized_end = 4789
+ _STATCROSSTAB._serialized_start = 4791
+ _STATCROSSTAB._serialized_end = 4892
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 4894
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5008
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5011
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5270
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5203
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5270
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 53f75b7520f..e706fa3e11d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -924,18 +924,6 @@ class Sample(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
- class Seed(google.protobuf.message.Message):
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
- SEED_FIELD_NUMBER: builtins.int
- seed: builtins.int
- def __init__(
- self,
- *,
- seed: builtins.int = ...,
- ) -> None: ...
- def ClearField(self, field_name: typing_extensions.Literal["seed", b"seed"]) -> None: ...
-
INPUT_FIELD_NUMBER: builtins.int
LOWER_BOUND_FIELD_NUMBER: builtins.int
UPPER_BOUND_FIELD_NUMBER: builtins.int
@@ -946,8 +934,7 @@ class Sample(google.protobuf.message.Message):
lower_bound: builtins.float
upper_bound: builtins.float
with_replacement: builtins.bool
- @property
- def seed(self) -> global___Sample.Seed: ...
+ seed: builtins.int
def __init__(
self,
*,
@@ -955,14 +942,19 @@ class Sample(google.protobuf.message.Message):
lower_bound: builtins.float = ...,
upper_bound: builtins.float = ...,
with_replacement: builtins.bool = ...,
- seed: global___Sample.Seed | None = ...,
+ seed: builtins.int | None = ...,
) -> None: ...
def HasField(
- self, field_name: typing_extensions.Literal["input", b"input", "seed", b"seed"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_seed", b"_seed", "input", b"input", "seed", b"seed"
+ ],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
+ "_seed",
+ b"_seed",
"input",
b"input",
"lower_bound",
@@ -975,6 +967,9 @@ class Sample(google.protobuf.message.Message):
b"with_replacement",
],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_seed", b"_seed"]
+ ) -> typing_extensions.Literal["seed"] | None: ...
global___Sample = Sample
@@ -983,20 +978,6 @@ class Range(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
- class NumPartitions(google.protobuf.message.Message):
- DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
- NUM_PARTITIONS_FIELD_NUMBER: builtins.int
- num_partitions: builtins.int
- def __init__(
- self,
- *,
- num_partitions: builtins.int = ...,
- ) -> None: ...
- def ClearField(
- self, field_name: typing_extensions.Literal["num_partitions", b"num_partitions"]
- ) -> None: ...
-
START_FIELD_NUMBER: builtins.int
END_FIELD_NUMBER: builtins.int
STEP_FIELD_NUMBER: builtins.int
@@ -1007,28 +988,42 @@ class Range(google.protobuf.message.Message):
"""Required."""
step: builtins.int
"""Required."""
- @property
- def num_partitions(self) -> global___Range.NumPartitions:
- """Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
- it is set, or 2) spark default parallelism.
- """
+ num_partitions: builtins.int
+ """Optional. Default value is assigned by 1) SQL conf "spark.sql.leafNodeDefaultParallelism" if
+ it is set, or 2) spark default parallelism.
+ """
def __init__(
self,
*,
start: builtins.int = ...,
end: builtins.int = ...,
step: builtins.int = ...,
- num_partitions: global___Range.NumPartitions | None = ...,
+ num_partitions: builtins.int | None = ...,
) -> None: ...
def HasField(
- self, field_name: typing_extensions.Literal["num_partitions", b"num_partitions"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_num_partitions", b"_num_partitions", "num_partitions", b"num_partitions"
+ ],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "end", b"end", "num_partitions", b"num_partitions", "start", b"start", "step", b"step"
+ "_num_partitions",
+ b"_num_partitions",
+ "end",
+ b"end",
+ "num_partitions",
+ b"num_partitions",
+ "start",
+ b"start",
+ "step",
+ b"step",
],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_num_partitions", b"_num_partitions"]
+ ) -> typing_extensions.Literal["num_partitions"] | None: ...
global___Range = Range
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
index c46d4d10624..4e26581a002 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py
@@ -121,7 +121,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
self.assertEqual(plan.root.sample.lower_bound, 0.0)
self.assertEqual(plan.root.sample.upper_bound, 0.4)
self.assertEqual(plan.root.sample.with_replacement, True)
- self.assertEqual(plan.root.sample.seed.seed, -1)
+ self.assertEqual(plan.root.sample.seed, -1)
def test_sort(self):
df = self.connect.readTable(table_name=self.tbl_name)
@@ -180,7 +180,7 @@ class SparkConnectTestsPlanOnly(PlanOnlyTestFixture):
self.assertEqual(plan.root.range.start, 10)
self.assertEqual(plan.root.range.end, 20)
self.assertEqual(plan.root.range.step, 3)
- self.assertEqual(plan.root.range.num_partitions.num_partitions, 4)
+ self.assertEqual(plan.root.range.num_partitions, 4)
plan = self.connect.range(start=10, end=20)._plan.to_proto(self.connect)
self.assertEqual(plan.root.range.start, 10)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org