You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/11/09 07:48:44 UTC
[spark] branch master updated: [SPARK-40992][CONNECT] Support toDF(columnNames) in Connect DSL
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e1382c566b7 [SPARK-40992][CONNECT] Support toDF(columnNames) in Connect DSL
e1382c566b7 is described below
commit e1382c566b7b2ba324fec1aed6556325ebe43f7b
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Wed Nov 9 15:48:24 2022 +0800
[SPARK-40992][CONNECT] Support toDF(columnNames) in Connect DSL
### What changes were proposed in this pull request?
Add `RenameColumns` to proto to support the implementation for `toDF(columnNames: String*)`, which renames the columns of the input relation to a different set of column names.
### Why are the changes needed?
Improve API coverage.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
UT
Closes #38475 from amaliujia/SPARK-40992.
Authored-by: Rui Wang <ru...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../main/protobuf/spark/connect/relations.proto | 12 ++
.../org/apache/spark/sql/connect/dsl/package.scala | 10 ++
.../sql/connect/planner/SparkConnectPlanner.scala | 9 ++
.../connect/planner/SparkConnectProtoSuite.scala | 4 +
python/pyspark/sql/connect/proto/relations_pb2.py | 126 +++++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 44 +++++++
6 files changed, 143 insertions(+), 62 deletions(-)
diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index dd03bd86940..cce9f3b939e 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -47,6 +47,7 @@ message Relation {
Range range = 15;
SubqueryAlias subquery_alias = 16;
Repartition repartition = 17;
+ RenameColumns rename_columns = 18;
StatFunction stat_function = 100;
@@ -274,3 +275,14 @@ message StatFunction {
}
}
+// Rename columns on the input relation.
+message RenameColumns {
+ // Required. The input relation.
+ Relation input = 1;
+
+ // Required.
+ //
+ // The number of columns of the input relation must be equal to the length
+ // of this field. If this is not true, an exception will be returned.
+ repeated string column_names = 2;
+}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 3e68b101057..d6f7a6756c3 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -457,6 +457,16 @@ package object dsl {
.build()
}
+ def toDF(columnNames: String*): Relation =
+ Relation
+ .newBuilder()
+ .setRenameColumns(
+ RenameColumns
+ .newBuilder()
+ .setInput(logicalPlan)
+ .addAllColumnNames(columnNames.asJava))
+ .build()
+
private def createSetOperation(
left: Relation,
right: Relation,
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 3bbdbf80276..87716c702b5 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -69,6 +69,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
case proto.Relation.RelTypeCase.STAT_FUNCTION =>
transformStatFunction(rel.getStatFunction)
+ case proto.Relation.RelTypeCase.RENAME_COLUMNS =>
+ transformRenameColumns(rel.getRenameColumns)
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -133,6 +135,13 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
}
}
+ private def transformRenameColumns(rel: proto.RenameColumns): LogicalPlan = {
+ Dataset
+ .ofRows(session, transformRelation(rel.getInput))
+ .toDF(rel.getColumnNamesList.asScala.toSeq: _*)
+ .logicalPlan
+ }
+
private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
if (!rel.hasInput) {
throw InvalidPlanInput("Deduplicate needs a plan input")
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index c5b6f4fc0ee..2339c676a38 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -267,6 +267,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
sparkTestRelation.summary("count", "mean", "stddev"))
}
+ test("Test toDF") {
+ comparePlans(connectTestRelation.toDF("col1", "col2"), sparkTestRelation.toDF("col1", "col2"))
+ }
+
private def createLocalRelationProtoByQualifiedAttributes(
attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
val localRelationBuilder = proto.LocalRelation.newBuilder()
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index b11a4b0e91a..06b59ea5f45 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x90\x08\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xd7\x08\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -44,65 +44,67 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_READ_DATASOURCE_OPTIONSENTRY._options = None
_READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001"
_RELATION._serialized_start = 82
- _RELATION._serialized_end = 1122
- _UNKNOWN._serialized_start = 1124
- _UNKNOWN._serialized_end = 1133
- _RELATIONCOMMON._serialized_start = 1135
- _RELATIONCOMMON._serialized_end = 1184
- _SQL._serialized_start = 1186
- _SQL._serialized_end = 1213
- _READ._serialized_start = 1216
- _READ._serialized_end = 1626
- _READ_NAMEDTABLE._serialized_start = 1358
- _READ_NAMEDTABLE._serialized_end = 1419
- _READ_DATASOURCE._serialized_start = 1422
- _READ_DATASOURCE._serialized_end = 1613
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1555
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1613
- _PROJECT._serialized_start = 1628
- _PROJECT._serialized_end = 1745
- _FILTER._serialized_start = 1747
- _FILTER._serialized_end = 1859
- _JOIN._serialized_start = 1862
- _JOIN._serialized_end = 2312
- _JOIN_JOINTYPE._serialized_start = 2125
- _JOIN_JOINTYPE._serialized_end = 2312
- _SETOPERATION._serialized_start = 2315
- _SETOPERATION._serialized_end = 2678
- _SETOPERATION_SETOPTYPE._serialized_start = 2564
- _SETOPERATION_SETOPTYPE._serialized_end = 2678
- _LIMIT._serialized_start = 2680
- _LIMIT._serialized_end = 2756
- _OFFSET._serialized_start = 2758
- _OFFSET._serialized_end = 2837
- _AGGREGATE._serialized_start = 2840
- _AGGREGATE._serialized_end = 3050
- _SORT._serialized_start = 3053
- _SORT._serialized_end = 3584
- _SORT_SORTFIELD._serialized_start = 3202
- _SORT_SORTFIELD._serialized_end = 3390
- _SORT_SORTDIRECTION._serialized_start = 3392
- _SORT_SORTDIRECTION._serialized_end = 3500
- _SORT_SORTNULLS._serialized_start = 3502
- _SORT_SORTNULLS._serialized_end = 3584
- _DEDUPLICATE._serialized_start = 3587
- _DEDUPLICATE._serialized_end = 3729
- _LOCALRELATION._serialized_start = 3731
- _LOCALRELATION._serialized_end = 3824
- _SAMPLE._serialized_start = 3827
- _SAMPLE._serialized_end = 4067
- _SAMPLE_SEED._serialized_start = 4041
- _SAMPLE_SEED._serialized_end = 4067
- _RANGE._serialized_start = 4070
- _RANGE._serialized_end = 4268
- _RANGE_NUMPARTITIONS._serialized_start = 4214
- _RANGE_NUMPARTITIONS._serialized_end = 4268
- _SUBQUERYALIAS._serialized_start = 4270
- _SUBQUERYALIAS._serialized_end = 4384
- _REPARTITION._serialized_start = 4386
- _REPARTITION._serialized_end = 4511
- _STATFUNCTION._serialized_start = 4514
- _STATFUNCTION._serialized_end = 4748
- _STATFUNCTION_SUMMARY._serialized_start = 4695
- _STATFUNCTION_SUMMARY._serialized_end = 4736
+ _RELATION._serialized_end = 1193
+ _UNKNOWN._serialized_start = 1195
+ _UNKNOWN._serialized_end = 1204
+ _RELATIONCOMMON._serialized_start = 1206
+ _RELATIONCOMMON._serialized_end = 1255
+ _SQL._serialized_start = 1257
+ _SQL._serialized_end = 1284
+ _READ._serialized_start = 1287
+ _READ._serialized_end = 1697
+ _READ_NAMEDTABLE._serialized_start = 1429
+ _READ_NAMEDTABLE._serialized_end = 1490
+ _READ_DATASOURCE._serialized_start = 1493
+ _READ_DATASOURCE._serialized_end = 1684
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1626
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1684
+ _PROJECT._serialized_start = 1699
+ _PROJECT._serialized_end = 1816
+ _FILTER._serialized_start = 1818
+ _FILTER._serialized_end = 1930
+ _JOIN._serialized_start = 1933
+ _JOIN._serialized_end = 2383
+ _JOIN_JOINTYPE._serialized_start = 2196
+ _JOIN_JOINTYPE._serialized_end = 2383
+ _SETOPERATION._serialized_start = 2386
+ _SETOPERATION._serialized_end = 2749
+ _SETOPERATION_SETOPTYPE._serialized_start = 2635
+ _SETOPERATION_SETOPTYPE._serialized_end = 2749
+ _LIMIT._serialized_start = 2751
+ _LIMIT._serialized_end = 2827
+ _OFFSET._serialized_start = 2829
+ _OFFSET._serialized_end = 2908
+ _AGGREGATE._serialized_start = 2911
+ _AGGREGATE._serialized_end = 3121
+ _SORT._serialized_start = 3124
+ _SORT._serialized_end = 3655
+ _SORT_SORTFIELD._serialized_start = 3273
+ _SORT_SORTFIELD._serialized_end = 3461
+ _SORT_SORTDIRECTION._serialized_start = 3463
+ _SORT_SORTDIRECTION._serialized_end = 3571
+ _SORT_SORTNULLS._serialized_start = 3573
+ _SORT_SORTNULLS._serialized_end = 3655
+ _DEDUPLICATE._serialized_start = 3658
+ _DEDUPLICATE._serialized_end = 3800
+ _LOCALRELATION._serialized_start = 3802
+ _LOCALRELATION._serialized_end = 3895
+ _SAMPLE._serialized_start = 3898
+ _SAMPLE._serialized_end = 4138
+ _SAMPLE_SEED._serialized_start = 4112
+ _SAMPLE_SEED._serialized_end = 4138
+ _RANGE._serialized_start = 4141
+ _RANGE._serialized_end = 4339
+ _RANGE_NUMPARTITIONS._serialized_start = 4285
+ _RANGE_NUMPARTITIONS._serialized_end = 4339
+ _SUBQUERYALIAS._serialized_start = 4341
+ _SUBQUERYALIAS._serialized_end = 4455
+ _REPARTITION._serialized_start = 4457
+ _REPARTITION._serialized_end = 4582
+ _STATFUNCTION._serialized_start = 4585
+ _STATFUNCTION._serialized_end = 4819
+ _STATFUNCTION_SUMMARY._serialized_start = 4766
+ _STATFUNCTION_SUMMARY._serialized_end = 4807
+ _RENAMECOLUMNS._serialized_start = 4821
+ _RENAMECOLUMNS._serialized_end = 4918
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 6ee3c46d7c5..bef74b03659 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -76,6 +76,7 @@ class Relation(google.protobuf.message.Message):
RANGE_FIELD_NUMBER: builtins.int
SUBQUERY_ALIAS_FIELD_NUMBER: builtins.int
REPARTITION_FIELD_NUMBER: builtins.int
+ RENAME_COLUMNS_FIELD_NUMBER: builtins.int
STAT_FUNCTION_FIELD_NUMBER: builtins.int
UNKNOWN_FIELD_NUMBER: builtins.int
@property
@@ -113,6 +114,8 @@ class Relation(google.protobuf.message.Message):
@property
def repartition(self) -> global___Repartition: ...
@property
+ def rename_columns(self) -> global___RenameColumns: ...
+ @property
def stat_function(self) -> global___StatFunction: ...
@property
def unknown(self) -> global___Unknown: ...
@@ -136,6 +139,7 @@ class Relation(google.protobuf.message.Message):
range: global___Range | None = ...,
subquery_alias: global___SubqueryAlias | None = ...,
repartition: global___Repartition | None = ...,
+ rename_columns: global___RenameColumns | None = ...,
stat_function: global___StatFunction | None = ...,
unknown: global___Unknown | None = ...,
) -> None: ...
@@ -166,6 +170,8 @@ class Relation(google.protobuf.message.Message):
b"read",
"rel_type",
b"rel_type",
+ "rename_columns",
+ b"rename_columns",
"repartition",
b"repartition",
"sample",
@@ -211,6 +217,8 @@ class Relation(google.protobuf.message.Message):
b"read",
"rel_type",
b"rel_type",
+ "rename_columns",
+ b"rename_columns",
"repartition",
b"repartition",
"sample",
@@ -248,6 +256,7 @@ class Relation(google.protobuf.message.Message):
"range",
"subquery_alias",
"repartition",
+ "rename_columns",
"stat_function",
"unknown",
] | None: ...
@@ -1133,3 +1142,38 @@ class StatFunction(google.protobuf.message.Message):
) -> typing_extensions.Literal["summary", "unknown"] | None: ...
global___StatFunction = StatFunction
+
+class RenameColumns(google.protobuf.message.Message):
+ """Rename columns on the input relation."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ INPUT_FIELD_NUMBER: builtins.int
+ COLUMN_NAMES_FIELD_NUMBER: builtins.int
+ @property
+ def input(self) -> global___Relation:
+ """Required. The input relation."""
+ @property
+ def column_names(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """Required.
+
+ The number of columns of the input relation must be equal to the length
+ of this field. If this is not true, an exception will be returned.
+ """
+ def __init__(
+ self,
+ *,
+ input: global___Relation | None = ...,
+ column_names: collections.abc.Iterable[builtins.str] | None = ...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["input", b"input"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal["column_names", b"column_names", "input", b"input"],
+ ) -> None: ...
+
+global___RenameColumns = RenameColumns
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org