Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/13 00:45:27 UTC

[spark] branch branch-3.4 updated: [SPARK-41963][CONNECT] Fix DataFrame.unpivot to raise the same error class when the `values` argument is empty

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 92d82698760 [SPARK-41963][CONNECT] Fix DataFrame.unpivot to raise the same error class when the `values` argument is empty
92d82698760 is described below

commit 92d8269876019e8580b2e60d3b3891ac13b5740b
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Mon Feb 13 09:45:05 2023 +0900

    [SPARK-41963][CONNECT] Fix DataFrame.unpivot to raise the same error class when the `values` argument is empty
    
    ### What changes were proposed in this pull request?
    
    Fixes `DataFrame.unpivot` in Spark Connect to raise the same error class as PySpark when the `values` argument is an empty list/tuple.
    
    ### Why are the changes needed?
    
    Currently `DataFrame.unpivot` raises different error classes: `UNPIVOT_REQUIRES_VALUE_COLUMNS` in PySpark vs. `UNPIVOT_VALUE_DATA_TYPE_MISMATCH` in Spark Connect.
    
    In `Unpivot`, an empty list/tuple passed as the `values` argument is different from `None`, so the two cases should be handled differently.
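    
    As a rough illustration of the difference (a minimal sketch, assuming a running SparkSession named `spark`; the DataFrame and column names are made up for the example):
    
        df = spark.createDataFrame([(1, 11, 1.1), (2, 12, 1.2)], ["id", "int", "double"])
    
        # values=None means "unpivot all non-id columns" and is valid.
        df.unpivot("id", None, "var", "val").show()
    
        # values=[] is an explicit empty list and should fail analysis with the
        # error class UNPIVOT_REQUIRES_VALUE_COLUMNS, on both PySpark and
        # Spark Connect after this change.
        df.unpivot("id", [], "var", "val")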
    
    ### Does this PR introduce _any_ user-facing change?
    
    `DataFrame.unpivot` on Spark Connect will now raise the same error class as PySpark.
    
    ### How was this patch tested?
    
    Enabled `DataFrameParityTests.test_unpivot_negative`.
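    
    A sketch of the kind of check such a parity test can now make on both backends (the exact assertions in the real test may differ; `getErrorClass` comes from the `pyspark.errors` exception API):
    
        from pyspark.errors import AnalysisException
    
        with self.assertRaises(AnalysisException) as ctx:
            df.unpivot("id", [], "var", "val").collect()
        # Spark Connect now reports the same error class as PySpark.
        self.assertEqual(ctx.exception.getErrorClass(), "UNPIVOT_REQUIRES_VALUE_COLUMNS")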
    
    Closes #39960 from ueshin/issues/SPARK-41963/unpivot.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
    (cherry picked from commit 633a486c65067b483524b079810b5590ac482a48)
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../main/protobuf/spark/connect/relations.proto    |  6 +++-
 .../org/apache/spark/sql/connect/dsl/package.scala |  5 ++-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  4 +--
 python/pyspark/sql/connect/dataframe.py            |  8 +++--
 python/pyspark/sql/connect/plan.py                 |  5 +--
 python/pyspark/sql/connect/proto/relations_pb2.py  | 25 ++++++++++----
 python/pyspark/sql/connect/proto/relations_pb2.pyi | 39 +++++++++++++++++-----
 .../pyspark/sql/tests/connect/test_connect_plan.py | 19 +++++++----
 .../sql/tests/connect/test_parity_dataframe.py     |  5 ---
 9 files changed, 83 insertions(+), 33 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 3d597fd2744..ea1216957d8 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -732,13 +732,17 @@ message Unpivot {
   repeated Expression ids = 2;
 
   // (Optional) Value columns to unpivot.
-  repeated Expression values = 3;
+  optional Values values = 3;
 
   // (Required) Name of the variable column.
   string variable_column_name = 4;
 
   // (Required) Name of the value column.
   string value_column_name = 5;
+
+  message Values {
+    repeated Expression values = 1;
+  }
 }
 
 message ToSchema {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 88531286e24..f91040a1009 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -990,7 +990,10 @@ package object dsl {
               .newBuilder()
               .setInput(logicalPlan)
               .addAllIds(ids.asJava)
-              .addAllValues(values.asJava)
+              .setValues(Unpivot.Values
+                .newBuilder()
+                .addAllValues(values.asJava)
+                .build())
               .setVariableColumnName(variableColumnName)
               .setValueColumnName(valueColumnName))
           .build()
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 194588fe89b..740d6b85964 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -515,7 +515,7 @@ class SparkConnectPlanner(val session: SparkSession) {
       Column(transformExpression(expr))
     }
 
-    if (rel.getValuesList.isEmpty) {
+    if (!rel.hasValues) {
       Unpivot(
         Some(ids.map(_.named)),
         None,
@@ -524,7 +524,7 @@ class SparkConnectPlanner(val session: SparkSession) {
         Seq(rel.getValueColumnName),
         transformRelation(rel.getInput))
     } else {
-      val values = rel.getValuesList.asScala.toArray.map { expr =>
+      val values = rel.getValues.getValuesList.asScala.toArray.map { expr =>
         Column(transformExpression(expr))
       }
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 536c909883e..95e39f93dc0 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -676,7 +676,7 @@ class DataFrame:
 
     def unpivot(
         self,
-        ids: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+        ids: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]],
         values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
         variableColumnName: str,
         valueColumnName: str,
@@ -698,7 +698,11 @@ class DataFrame:
 
         return DataFrame.withPlan(
             plan.Unpivot(
-                self._plan, to_jcols(ids), to_jcols(values), variableColumnName, valueColumnName
+                self._plan,
+                to_jcols(ids),
+                to_jcols(values) if values is not None else None,
+                variableColumnName,
+                valueColumnName,
             ),
             self._session,
         )
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 4675631627a..ced0e4008e1 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1001,7 +1001,7 @@ class Unpivot(LogicalPlan):
         self,
         child: Optional["LogicalPlan"],
         ids: List["ColumnOrName"],
-        values: List["ColumnOrName"],
+        values: Optional[List["ColumnOrName"]],
         variable_column_name: str,
         value_column_name: str,
     ) -> None:
@@ -1023,7 +1023,8 @@ class Unpivot(LogicalPlan):
         plan = proto.Relation()
         plan.unpivot.input.CopyFrom(self._child.plan(session))
         plan.unpivot.ids.extend([self.col_to_expr(x, session) for x in self.ids])
-        plan.unpivot.values.extend([self.col_to_expr(x, session) for x in self.values])
+        if self.values is not None:
+            plan.unpivot.values.values.extend([self.col_to_expr(x, session) for x in self.values])
         plan.unpivot.variable_column_name = self.variable_column_name
         plan.unpivot.value_column_name = self.value_column_name
         return plan
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 1a24628ef30..ece5920953e 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -88,6 +88,7 @@ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY = _WITHCOLUMNSRENAMED.nested_types_by_
 _WITHCOLUMNS = DESCRIPTOR.message_types_by_name["WithColumns"]
 _HINT = DESCRIPTOR.message_types_by_name["Hint"]
 _UNPIVOT = DESCRIPTOR.message_types_by_name["Unpivot"]
+_UNPIVOT_VALUES = _UNPIVOT.nested_types_by_name["Values"]
 _TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"]
 _REPARTITIONBYEXPRESSION = DESCRIPTOR.message_types_by_name["RepartitionByExpression"]
 _JOIN_JOINTYPE = _JOIN.enum_types_by_name["JoinType"]
@@ -584,12 +585,22 @@ Unpivot = _reflection.GeneratedProtocolMessageType(
     "Unpivot",
     (_message.Message,),
     {
+        "Values": _reflection.GeneratedProtocolMessageType(
+            "Values",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _UNPIVOT_VALUES,
+                "__module__": "spark.connect.relations_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.Unpivot.Values)
+            },
+        ),
         "DESCRIPTOR": _UNPIVOT,
         "__module__": "spark.connect.relations_pb2"
         # @@protoc_insertion_point(class_scope:spark.connect.Unpivot)
     },
 )
 _sym_db.RegisterMessage(Unpivot)
+_sym_db.RegisterMessage(Unpivot.Values)
 
 ToSchema = _reflection.GeneratedProtocolMessageType(
     "ToSchema",
@@ -720,9 +731,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _HINT._serialized_start = 8635
     _HINT._serialized_end = 8767
     _UNPIVOT._serialized_start = 8770
-    _UNPIVOT._serialized_end = 9016
-    _TOSCHEMA._serialized_start = 9018
-    _TOSCHEMA._serialized_end = 9124
-    _REPARTITIONBYEXPRESSION._serialized_start = 9127
-    _REPARTITIONBYEXPRESSION._serialized_end = 9330
+    _UNPIVOT._serialized_end = 9097
+    _UNPIVOT_VALUES._serialized_start = 9027
+    _UNPIVOT_VALUES._serialized_end = 9086
+    _TOSCHEMA._serialized_start = 9099
+    _TOSCHEMA._serialized_end = 9205
+    _REPARTITIONBYEXPRESSION._serialized_start = 9208
+    _REPARTITIONBYEXPRESSION._serialized_end = 9411
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 647b26b6d31..41962ee4062 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -2459,6 +2459,26 @@ class Unpivot(google.protobuf.message.Message):
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
+    class Values(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        VALUES_FIELD_NUMBER: builtins.int
+        @property
+        def values(
+            self,
+        ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+            pyspark.sql.connect.proto.expressions_pb2.Expression
+        ]: ...
+        def __init__(
+            self,
+            *,
+            values: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
+            | None = ...,
+        ) -> None: ...
+        def ClearField(
+            self, field_name: typing_extensions.Literal["values", b"values"]
+        ) -> None: ...
+
     INPUT_FIELD_NUMBER: builtins.int
     IDS_FIELD_NUMBER: builtins.int
     VALUES_FIELD_NUMBER: builtins.int
@@ -2475,11 +2495,7 @@ class Unpivot(google.protobuf.message.Message):
     ]:
         """(Required) Id columns."""
     @property
-    def values(
-        self,
-    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
-        pyspark.sql.connect.proto.expressions_pb2.Expression
-    ]:
+    def values(self) -> global___Unpivot.Values:
         """(Optional) Value columns to unpivot."""
     variable_column_name: builtins.str
     """(Required) Name of the variable column."""
@@ -2491,17 +2507,21 @@ class Unpivot(google.protobuf.message.Message):
         input: global___Relation | None = ...,
         ids: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
         | None = ...,
-        values: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
-        | None = ...,
+        values: global___Unpivot.Values | None = ...,
         variable_column_name: builtins.str = ...,
         value_column_name: builtins.str = ...,
     ) -> None: ...
     def HasField(
-        self, field_name: typing_extensions.Literal["input", b"input"]
+        self,
+        field_name: typing_extensions.Literal[
+            "_values", b"_values", "input", b"input", "values", b"values"
+        ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "_values",
+            b"_values",
             "ids",
             b"ids",
             "input",
@@ -2514,6 +2534,9 @@ class Unpivot(google.protobuf.message.Message):
             b"variable_column_name",
         ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_values", b"_values"]
+    ) -> typing_extensions.Literal["values"] | None: ...
 
 global___Unpivot = Unpivot
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 980f61eb4b1..1892e64f8f9 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -192,9 +192,12 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
         )
         self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
         self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
-        self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values))
+        self.assertEqual(plan.root.unpivot.HasField("values"), True)
+        self.assertTrue(
+            all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values.values)
+        )
         self.assertEqual(
-            plan.root.unpivot.values[0].unresolved_attribute.unparsed_identifier, "name"
+            plan.root.unpivot.values.values[0].unresolved_attribute.unparsed_identifier, "name"
         )
         self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
         self.assertEqual(plan.root.unpivot.value_column_name, "value")
@@ -207,7 +210,7 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
         self.assertTrue(len(plan.root.unpivot.ids) == 1)
         self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
         self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
-        self.assertTrue(len(plan.root.unpivot.values) == 0)
+        self.assertEqual(plan.root.unpivot.HasField("values"), False)
         self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
         self.assertEqual(plan.root.unpivot.value_column_name, "value")
 
@@ -221,9 +224,12 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
         )
         self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
         self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
-        self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values))
+        self.assertEqual(plan.root.unpivot.HasField("values"), True)
+        self.assertTrue(
+            all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values.values)
+        )
         self.assertEqual(
-            plan.root.unpivot.values[0].unresolved_attribute.unparsed_identifier, "name"
+            plan.root.unpivot.values.values[0].unresolved_attribute.unparsed_identifier, "name"
         )
         self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
         self.assertEqual(plan.root.unpivot.value_column_name, "value")
@@ -236,7 +242,8 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
         self.assertTrue(len(plan.root.unpivot.ids) == 1)
         self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
         self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
-        self.assertTrue(len(plan.root.unpivot.values) == 0)
+        self.assertEqual(plan.root.unpivot.HasField("values"), True)
+        self.assertTrue(len(plan.root.unpivot.values.values) == 0)
         self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
         self.assertEqual(plan.root.unpivot.value_column_name, "value")
 
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 8413dbaf06d..07cae0fb27d 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -142,11 +142,6 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_to_pandas_with_duplicated_column_names(self):
         super().test_to_pandas_with_duplicated_column_names()
 
-    # TODO(SPARK-41963): Different exception message in DataFrame.unpivot
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_unpivot_negative(self):
-        super().test_unpivot_negative()
-
 
 if __name__ == "__main__":
     import unittest


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org