Posted to commits@spark.apache.org by we...@apache.org on 2022/11/09 07:48:44 UTC

[spark] branch master updated: [SPARK-40992][CONNECT] Support toDF(columnNames) in Connect DSL

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e1382c566b7 [SPARK-40992][CONNECT] Support toDF(columnNames) in Connect DSL
e1382c566b7 is described below

commit e1382c566b7b2ba324fec1aed6556325ebe43f7b
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Wed Nov 9 15:48:24 2022 +0800

    [SPARK-40992][CONNECT] Support toDF(columnNames) in Connect DSL
    
    ### What changes were proposed in this pull request?
    
    Add a `RenameColumns` message to the Connect proto to support `toDF(columnNames: String*)`, which renames the columns of the input relation to a new set of column names.
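    
    A minimal sketch of the resulting DSL call, using the `connectTestRelation`
    fixture from `SparkConnectProtoSuite` (see the test added below):
    
        // Produces a Relation whose rename_columns field wraps the input plan.
        val renamed = connectTestRelation.toDF("col1", "col2")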
    
    ### Why are the changes needed?
    
    Improve API coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    UT (unit test added to `SparkConnectProtoSuite`).
    
    Closes #38475 from amaliujia/SPARK-40992.
    
    Authored-by: Rui Wang <ru...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../main/protobuf/spark/connect/relations.proto    |  12 ++
 .../org/apache/spark/sql/connect/dsl/package.scala |  10 ++
 .../sql/connect/planner/SparkConnectPlanner.scala  |   9 ++
 .../connect/planner/SparkConnectProtoSuite.scala   |   4 +
 python/pyspark/sql/connect/proto/relations_pb2.py  | 126 +++++++++++----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  44 +++++++
 6 files changed, 143 insertions(+), 62 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index dd03bd86940..cce9f3b939e 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -47,6 +47,7 @@ message Relation {
     Range range = 15;
     SubqueryAlias subquery_alias = 16;
     Repartition repartition = 17;
+    RenameColumns rename_columns = 18;
 
     StatFunction stat_function = 100;
 
@@ -274,3 +275,14 @@ message StatFunction {
   }
 }
 
+// Rename columns on the input relation.
+message RenameColumns {
+  // Required. The input relation.
+  Relation input = 1;
+
+  // Required.
+  //
+  // The number of columns of the input relation must be equal to the length
+  // of this field. If this is not true, an exception will be returned.
+  repeated string column_names = 2;
+}
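
For reference, a minimal sketch of how this message is assembled from Scala,
mirroring the DSL helper added below (the `org.apache.spark.connect.proto`
package for the generated bindings is assumed):

    import scala.collection.JavaConverters._
    import org.apache.spark.connect.proto.{Relation, RenameColumns}

    // Wrap an existing relation and attach the replacement column names.
    // The name count is not checked here; the server validates it when it
    // executes Dataset.toDF.
    def renameColumns(input: Relation, names: Seq[String]): Relation =
      Relation
        .newBuilder()
        .setRenameColumns(
          RenameColumns
            .newBuilder()
            .setInput(input)
            .addAllColumnNames(names.asJava))
        .build()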
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 3e68b101057..d6f7a6756c3 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -457,6 +457,16 @@ package object dsl {
           .build()
       }
 
+      def toDF(columnNames: String*): Relation =
+        Relation
+          .newBuilder()
+          .setRenameColumns(
+            RenameColumns
+              .newBuilder()
+              .setInput(logicalPlan)
+              .addAllColumnNames(columnNames.asJava))
+          .build()
+
       private def createSetOperation(
           left: Relation,
           right: Relation,
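
Note that this DSL helper is a pure proto builder: it does not validate the
number of names against the input schema. A sketch of the (intentionally)
unchecked client-side call:

    // Builds fine even though the test relation has two columns; the mismatch
    // only surfaces server side, when SparkConnectPlanner runs Dataset.toDF.
    val tooFew = connectTestRelation.toDF("only_one_name")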
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 3bbdbf80276..87716c702b5 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -69,6 +69,8 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
       case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
       case proto.Relation.RelTypeCase.STAT_FUNCTION =>
         transformStatFunction(rel.getStatFunction)
+      case proto.Relation.RelTypeCase.RENAME_COLUMNS =>
+        transformRenameColumns(rel.getRenameColumns)
       case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
         throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")
       case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
@@ -133,6 +135,13 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     }
   }
 
+  private def transformRenameColumns(rel: proto.RenameColumns): LogicalPlan = {
+    Dataset
+      .ofRows(session, transformRelation(rel.getInput))
+      .toDF(rel.getColumnNamesList.asScala.toSeq: _*)
+      .logicalPlan
+  }
+
   private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
     if (!rel.hasInput) {
       throw InvalidPlanInput("Deduplicate needs a plan input")
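
Since the planner delegates to Dataset.toDF, the rename carries the exact
semantics of the existing Dataset API. A server-side sketch of the equivalent
operation, assuming a SparkSession named `spark`:

    // Equivalent to what transformRenameColumns does for a two-column input.
    val df = spark.range(1).selectExpr("id AS a", "id + 1 AS b")
    val renamed = df.toDF("col1", "col2")   // schema becomes (col1, col2)
    // df.toDF("col1") would fail here: the name count must match the number
    // of columns, as documented on the RenameColumns proto.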
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index c5b6f4fc0ee..2339c676a38 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -267,6 +267,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
       sparkTestRelation.summary("count", "mean", "stddev"))
   }
 
+  test("Test toDF") {
+    comparePlans(connectTestRelation.toDF("col1", "col2"), sparkTestRelation.toDF("col1", "col2"))
+  }
+
   private def createLocalRelationProtoByQualifiedAttributes(
       attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index b11a4b0e91a..06b59ea5f45 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -32,7 +32,7 @@ from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_e
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\x90\x08\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xd7\x08\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -44,65 +44,67 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _READ_DATASOURCE_OPTIONSENTRY._options = None
     _READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001"
     _RELATION._serialized_start = 82
-    _RELATION._serialized_end = 1122
-    _UNKNOWN._serialized_start = 1124
-    _UNKNOWN._serialized_end = 1133
-    _RELATIONCOMMON._serialized_start = 1135
-    _RELATIONCOMMON._serialized_end = 1184
-    _SQL._serialized_start = 1186
-    _SQL._serialized_end = 1213
-    _READ._serialized_start = 1216
-    _READ._serialized_end = 1626
-    _READ_NAMEDTABLE._serialized_start = 1358
-    _READ_NAMEDTABLE._serialized_end = 1419
-    _READ_DATASOURCE._serialized_start = 1422
-    _READ_DATASOURCE._serialized_end = 1613
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1555
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1613
-    _PROJECT._serialized_start = 1628
-    _PROJECT._serialized_end = 1745
-    _FILTER._serialized_start = 1747
-    _FILTER._serialized_end = 1859
-    _JOIN._serialized_start = 1862
-    _JOIN._serialized_end = 2312
-    _JOIN_JOINTYPE._serialized_start = 2125
-    _JOIN_JOINTYPE._serialized_end = 2312
-    _SETOPERATION._serialized_start = 2315
-    _SETOPERATION._serialized_end = 2678
-    _SETOPERATION_SETOPTYPE._serialized_start = 2564
-    _SETOPERATION_SETOPTYPE._serialized_end = 2678
-    _LIMIT._serialized_start = 2680
-    _LIMIT._serialized_end = 2756
-    _OFFSET._serialized_start = 2758
-    _OFFSET._serialized_end = 2837
-    _AGGREGATE._serialized_start = 2840
-    _AGGREGATE._serialized_end = 3050
-    _SORT._serialized_start = 3053
-    _SORT._serialized_end = 3584
-    _SORT_SORTFIELD._serialized_start = 3202
-    _SORT_SORTFIELD._serialized_end = 3390
-    _SORT_SORTDIRECTION._serialized_start = 3392
-    _SORT_SORTDIRECTION._serialized_end = 3500
-    _SORT_SORTNULLS._serialized_start = 3502
-    _SORT_SORTNULLS._serialized_end = 3584
-    _DEDUPLICATE._serialized_start = 3587
-    _DEDUPLICATE._serialized_end = 3729
-    _LOCALRELATION._serialized_start = 3731
-    _LOCALRELATION._serialized_end = 3824
-    _SAMPLE._serialized_start = 3827
-    _SAMPLE._serialized_end = 4067
-    _SAMPLE_SEED._serialized_start = 4041
-    _SAMPLE_SEED._serialized_end = 4067
-    _RANGE._serialized_start = 4070
-    _RANGE._serialized_end = 4268
-    _RANGE_NUMPARTITIONS._serialized_start = 4214
-    _RANGE_NUMPARTITIONS._serialized_end = 4268
-    _SUBQUERYALIAS._serialized_start = 4270
-    _SUBQUERYALIAS._serialized_end = 4384
-    _REPARTITION._serialized_start = 4386
-    _REPARTITION._serialized_end = 4511
-    _STATFUNCTION._serialized_start = 4514
-    _STATFUNCTION._serialized_end = 4748
-    _STATFUNCTION_SUMMARY._serialized_start = 4695
-    _STATFUNCTION_SUMMARY._serialized_end = 4736
+    _RELATION._serialized_end = 1193
+    _UNKNOWN._serialized_start = 1195
+    _UNKNOWN._serialized_end = 1204
+    _RELATIONCOMMON._serialized_start = 1206
+    _RELATIONCOMMON._serialized_end = 1255
+    _SQL._serialized_start = 1257
+    _SQL._serialized_end = 1284
+    _READ._serialized_start = 1287
+    _READ._serialized_end = 1697
+    _READ_NAMEDTABLE._serialized_start = 1429
+    _READ_NAMEDTABLE._serialized_end = 1490
+    _READ_DATASOURCE._serialized_start = 1493
+    _READ_DATASOURCE._serialized_end = 1684
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1626
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1684
+    _PROJECT._serialized_start = 1699
+    _PROJECT._serialized_end = 1816
+    _FILTER._serialized_start = 1818
+    _FILTER._serialized_end = 1930
+    _JOIN._serialized_start = 1933
+    _JOIN._serialized_end = 2383
+    _JOIN_JOINTYPE._serialized_start = 2196
+    _JOIN_JOINTYPE._serialized_end = 2383
+    _SETOPERATION._serialized_start = 2386
+    _SETOPERATION._serialized_end = 2749
+    _SETOPERATION_SETOPTYPE._serialized_start = 2635
+    _SETOPERATION_SETOPTYPE._serialized_end = 2749
+    _LIMIT._serialized_start = 2751
+    _LIMIT._serialized_end = 2827
+    _OFFSET._serialized_start = 2829
+    _OFFSET._serialized_end = 2908
+    _AGGREGATE._serialized_start = 2911
+    _AGGREGATE._serialized_end = 3121
+    _SORT._serialized_start = 3124
+    _SORT._serialized_end = 3655
+    _SORT_SORTFIELD._serialized_start = 3273
+    _SORT_SORTFIELD._serialized_end = 3461
+    _SORT_SORTDIRECTION._serialized_start = 3463
+    _SORT_SORTDIRECTION._serialized_end = 3571
+    _SORT_SORTNULLS._serialized_start = 3573
+    _SORT_SORTNULLS._serialized_end = 3655
+    _DEDUPLICATE._serialized_start = 3658
+    _DEDUPLICATE._serialized_end = 3800
+    _LOCALRELATION._serialized_start = 3802
+    _LOCALRELATION._serialized_end = 3895
+    _SAMPLE._serialized_start = 3898
+    _SAMPLE._serialized_end = 4138
+    _SAMPLE_SEED._serialized_start = 4112
+    _SAMPLE_SEED._serialized_end = 4138
+    _RANGE._serialized_start = 4141
+    _RANGE._serialized_end = 4339
+    _RANGE_NUMPARTITIONS._serialized_start = 4285
+    _RANGE_NUMPARTITIONS._serialized_end = 4339
+    _SUBQUERYALIAS._serialized_start = 4341
+    _SUBQUERYALIAS._serialized_end = 4455
+    _REPARTITION._serialized_start = 4457
+    _REPARTITION._serialized_end = 4582
+    _STATFUNCTION._serialized_start = 4585
+    _STATFUNCTION._serialized_end = 4819
+    _STATFUNCTION_SUMMARY._serialized_start = 4766
+    _STATFUNCTION_SUMMARY._serialized_end = 4807
+    _RENAMECOLUMNS._serialized_start = 4821
+    _RENAMECOLUMNS._serialized_end = 4918
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 6ee3c46d7c5..bef74b03659 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -76,6 +76,7 @@ class Relation(google.protobuf.message.Message):
     RANGE_FIELD_NUMBER: builtins.int
     SUBQUERY_ALIAS_FIELD_NUMBER: builtins.int
     REPARTITION_FIELD_NUMBER: builtins.int
+    RENAME_COLUMNS_FIELD_NUMBER: builtins.int
     STAT_FUNCTION_FIELD_NUMBER: builtins.int
     UNKNOWN_FIELD_NUMBER: builtins.int
     @property
@@ -113,6 +114,8 @@ class Relation(google.protobuf.message.Message):
     @property
     def repartition(self) -> global___Repartition: ...
     @property
+    def rename_columns(self) -> global___RenameColumns: ...
+    @property
     def stat_function(self) -> global___StatFunction: ...
     @property
     def unknown(self) -> global___Unknown: ...
@@ -136,6 +139,7 @@ class Relation(google.protobuf.message.Message):
         range: global___Range | None = ...,
         subquery_alias: global___SubqueryAlias | None = ...,
         repartition: global___Repartition | None = ...,
+        rename_columns: global___RenameColumns | None = ...,
         stat_function: global___StatFunction | None = ...,
         unknown: global___Unknown | None = ...,
     ) -> None: ...
@@ -166,6 +170,8 @@ class Relation(google.protobuf.message.Message):
             b"read",
             "rel_type",
             b"rel_type",
+            "rename_columns",
+            b"rename_columns",
             "repartition",
             b"repartition",
             "sample",
@@ -211,6 +217,8 @@ class Relation(google.protobuf.message.Message):
             b"read",
             "rel_type",
             b"rel_type",
+            "rename_columns",
+            b"rename_columns",
             "repartition",
             b"repartition",
             "sample",
@@ -248,6 +256,7 @@ class Relation(google.protobuf.message.Message):
         "range",
         "subquery_alias",
         "repartition",
+        "rename_columns",
         "stat_function",
         "unknown",
     ] | None: ...
@@ -1133,3 +1142,38 @@ class StatFunction(google.protobuf.message.Message):
     ) -> typing_extensions.Literal["summary", "unknown"] | None: ...
 
 global___StatFunction = StatFunction
+
+class RenameColumns(google.protobuf.message.Message):
+    """Rename columns on the input relation."""
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    INPUT_FIELD_NUMBER: builtins.int
+    COLUMN_NAMES_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> global___Relation:
+        """Required. The input relation."""
+    @property
+    def column_names(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+        """Required.
+
+        The number of columns of the input relation must be equal to the length
+        of this field. If this is not true, an exception will be returned.
+        """
+    def __init__(
+        self,
+        *,
+        input: global___Relation | None = ...,
+        column_names: collections.abc.Iterable[builtins.str] | None = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["input", b"input"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal["column_names", b"column_names", "input", b"input"],
+    ) -> None: ...
+
+global___RenameColumns = RenameColumns

