Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/13 00:45:27 UTC
[spark] branch branch-3.4 updated: [SPARK-41963][CONNECT] Fix DataFrame.unpivot to raise the same error class when the `values` argument is empty
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 92d82698760 [SPARK-41963][CONNECT] Fix DataFrame.unpivot to raise the same error class when the `values` argument is empty
92d82698760 is described below
commit 92d8269876019e8580b2e60d3b3891ac13b5740b
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Mon Feb 13 09:45:05 2023 +0900
[SPARK-41963][CONNECT] Fix DataFrame.unpivot to raise the same error class when the `values` argument is empty
### What changes were proposed in this pull request?
Fixes `DataFrame.unpivot` to raise the same error class when the `values` argument is an empty list/tuple.
### Why are the changes needed?
Currently `DataFrame.unpivot` raises different error classes for an empty `values` argument: `UNPIVOT_REQUIRES_VALUE_COLUMNS` in PySpark vs. `UNPIVOT_VALUE_DATA_TYPE_MISMATCH` in Spark Connect.
In `Unpivot`, an empty list/tuple passed as the `values` argument means something different from `None`, so the two cases must be handled differently.
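As a minimal sketch of the intended semantics (assuming a local `SparkSession`; error class as named above):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 11, 1.1)], ["id", "int", "double"])

    # values=None: unpivot every non-id column (here long and double
    # coerce to a common double type).
    df.unpivot("id", None, "var", "val").show()

    # values=[]: an explicit empty selection; after this change both PySpark
    # and Spark Connect fail analysis with UNPIVOT_REQUIRES_VALUE_COLUMNS.
    df.unpivot("id", [], "var", "val")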
### Does this PR introduce _any_ user-facing change?
Yes. `DataFrame.unpivot` on Spark Connect will now raise the same error class as PySpark.
### How was this patch tested?
Enabled `DataFrameParityTests.test_unpivot_negative`.
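Continuing the sketch above, the re-enabled parity test boils down to an assertion of roughly this shape (hedged; the actual body lives in the shared `DataFrameTestsMixin`):

    from pyspark.errors import AnalysisException

    try:
        df.unpivot("id", [], "var", "val").collect()
        raise AssertionError("expected UNPIVOT_REQUIRES_VALUE_COLUMNS")
    except AnalysisException as e:
        assert "UNPIVOT_REQUIRES_VALUE_COLUMNS" in str(e)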
Closes #39960 from ueshin/issues/SPARK-41963/unpivot.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
(cherry picked from commit 633a486c65067b483524b079810b5590ac482a48)
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../main/protobuf/spark/connect/relations.proto | 6 +++-
.../org/apache/spark/sql/connect/dsl/package.scala | 5 ++-
.../sql/connect/planner/SparkConnectPlanner.scala | 4 +--
python/pyspark/sql/connect/dataframe.py | 8 +++--
python/pyspark/sql/connect/plan.py | 5 +--
python/pyspark/sql/connect/proto/relations_pb2.py | 25 ++++++++++----
python/pyspark/sql/connect/proto/relations_pb2.pyi | 39 +++++++++++++++++-----
.../pyspark/sql/tests/connect/test_connect_plan.py | 19 +++++++----
.../sql/tests/connect/test_parity_dataframe.py | 5 ---
9 files changed, 83 insertions(+), 33 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 3d597fd2744..ea1216957d8 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -732,13 +732,17 @@ message Unpivot {
repeated Expression ids = 2;
// (Optional) Value columns to unpivot.
- repeated Expression values = 3;
+ optional Values values = 3;
// (Required) Name of the variable column.
string variable_column_name = 4;
// (Required) Name of the value column.
string value_column_name = 5;
+
+ message Values {
+ repeated Expression values = 1;
+ }
}
message ToSchema {
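The wrapper message is the point of the change: proto3 repeated fields carry no presence bit, so an absent list and an empty list are indistinguishable on the wire. Wrapping the list in an `optional` message makes "not set" observable. A quick illustration against the regenerated Python bindings (a sketch):

    from pyspark.sql.connect.proto import relations_pb2

    rel = relations_pb2.Unpivot()
    assert not rel.HasField("values")   # values=None: wrapper never set

    rel.values.SetInParent()            # materialize an empty Values wrapper
    assert rel.HasField("values")       # values=[]: present but empty
    assert len(rel.values.values) == 0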
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 88531286e24..f91040a1009 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -990,7 +990,10 @@ package object dsl {
.newBuilder()
.setInput(logicalPlan)
.addAllIds(ids.asJava)
- .addAllValues(values.asJava)
+ .setValues(Unpivot.Values
+ .newBuilder()
+ .addAllValues(values.asJava)
+ .build())
.setVariableColumnName(variableColumnName)
.setValueColumnName(valueColumnName))
.build()
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 194588fe89b..740d6b85964 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -515,7 +515,7 @@ class SparkConnectPlanner(val session: SparkSession) {
Column(transformExpression(expr))
}
- if (rel.getValuesList.isEmpty) {
+ if (!rel.hasValues) {
Unpivot(
Some(ids.map(_.named)),
None,
@@ -524,7 +524,7 @@ class SparkConnectPlanner(val session: SparkSession) {
Seq(rel.getValueColumnName),
transformRelation(rel.getInput))
} else {
- val values = rel.getValuesList.asScala.toArray.map { expr =>
+ val values = rel.getValues.getValuesList.asScala.toArray.map { expr =>
Column(transformExpression(expr))
}
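The planner now keys off wrapper presence rather than list length. As a hedged Python paraphrase of the Scala branching above (a hypothetical helper, not the actual server code):

    from typing import List, Optional

    from pyspark.sql.connect.proto import expressions_pb2
    from pyspark.sql.connect.proto import relations_pb2

    def unpivot_values(
        rel: relations_pb2.Unpivot,
    ) -> Optional[List[expressions_pb2.Expression]]:
        # Absent wrapper -> None: Catalyst infers the value columns.
        if not rel.HasField("values"):
            return None
        # A present wrapper is taken literally; an empty list reaches the
        # analyzer, which raises UNPIVOT_REQUIRES_VALUE_COLUMNS.
        return list(rel.values.values)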
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 536c909883e..95e39f93dc0 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -676,7 +676,7 @@ class DataFrame:
def unpivot(
self,
- ids: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+ ids: Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]],
values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
variableColumnName: str,
valueColumnName: str,
@@ -698,7 +698,11 @@ class DataFrame:
return DataFrame.withPlan(
plan.Unpivot(
- self._plan, to_jcols(ids), to_jcols(values), variableColumnName, valueColumnName
+ self._plan,
+ to_jcols(ids),
+ to_jcols(values) if values is not None else None,
+ variableColumnName,
+ valueColumnName,
),
self._session,
)
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 4675631627a..ced0e4008e1 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1001,7 +1001,7 @@ class Unpivot(LogicalPlan):
self,
child: Optional["LogicalPlan"],
ids: List["ColumnOrName"],
- values: List["ColumnOrName"],
+ values: Optional[List["ColumnOrName"]],
variable_column_name: str,
value_column_name: str,
) -> None:
@@ -1023,7 +1023,8 @@ class Unpivot(LogicalPlan):
plan = proto.Relation()
plan.unpivot.input.CopyFrom(self._child.plan(session))
plan.unpivot.ids.extend([self.col_to_expr(x, session) for x in self.ids])
- plan.unpivot.values.extend([self.col_to_expr(x, session) for x in self.values])
+ if self.values is not None:
+ plan.unpivot.values.values.extend([self.col_to_expr(x, session) for x in self.values])
plan.unpivot.variable_column_name = self.variable_column_name
plan.unpivot.value_column_name = self.value_column_name
return plan
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 1a24628ef30..ece5920953e 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf9\x11\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -88,6 +88,7 @@ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY = _WITHCOLUMNSRENAMED.nested_types_by_
_WITHCOLUMNS = DESCRIPTOR.message_types_by_name["WithColumns"]
_HINT = DESCRIPTOR.message_types_by_name["Hint"]
_UNPIVOT = DESCRIPTOR.message_types_by_name["Unpivot"]
+_UNPIVOT_VALUES = _UNPIVOT.nested_types_by_name["Values"]
_TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"]
_REPARTITIONBYEXPRESSION = DESCRIPTOR.message_types_by_name["RepartitionByExpression"]
_JOIN_JOINTYPE = _JOIN.enum_types_by_name["JoinType"]
@@ -584,12 +585,22 @@ Unpivot = _reflection.GeneratedProtocolMessageType(
"Unpivot",
(_message.Message,),
{
+ "Values": _reflection.GeneratedProtocolMessageType(
+ "Values",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _UNPIVOT_VALUES,
+ "__module__": "spark.connect.relations_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.Unpivot.Values)
+ },
+ ),
"DESCRIPTOR": _UNPIVOT,
"__module__": "spark.connect.relations_pb2"
# @@protoc_insertion_point(class_scope:spark.connect.Unpivot)
},
)
_sym_db.RegisterMessage(Unpivot)
+_sym_db.RegisterMessage(Unpivot.Values)
ToSchema = _reflection.GeneratedProtocolMessageType(
"ToSchema",
@@ -720,9 +731,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_HINT._serialized_start = 8635
_HINT._serialized_end = 8767
_UNPIVOT._serialized_start = 8770
- _UNPIVOT._serialized_end = 9016
- _TOSCHEMA._serialized_start = 9018
- _TOSCHEMA._serialized_end = 9124
- _REPARTITIONBYEXPRESSION._serialized_start = 9127
- _REPARTITIONBYEXPRESSION._serialized_end = 9330
+ _UNPIVOT._serialized_end = 9097
+ _UNPIVOT_VALUES._serialized_start = 9027
+ _UNPIVOT_VALUES._serialized_end = 9086
+ _TOSCHEMA._serialized_start = 9099
+ _TOSCHEMA._serialized_end = 9205
+ _REPARTITIONBYEXPRESSION._serialized_start = 9208
+ _REPARTITIONBYEXPRESSION._serialized_end = 9411
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 647b26b6d31..41962ee4062 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -2459,6 +2459,26 @@ class Unpivot(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
+ class Values(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ VALUES_FIELD_NUMBER: builtins.int
+ @property
+ def values(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]: ...
+ def __init__(
+ self,
+ *,
+ values: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
+ | None = ...,
+ ) -> None: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["values", b"values"]
+ ) -> None: ...
+
INPUT_FIELD_NUMBER: builtins.int
IDS_FIELD_NUMBER: builtins.int
VALUES_FIELD_NUMBER: builtins.int
@@ -2475,11 +2495,7 @@ class Unpivot(google.protobuf.message.Message):
]:
"""(Required) Id columns."""
@property
- def values(
- self,
- ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
- pyspark.sql.connect.proto.expressions_pb2.Expression
- ]:
+ def values(self) -> global___Unpivot.Values:
"""(Optional) Value columns to unpivot."""
variable_column_name: builtins.str
"""(Required) Name of the variable column."""
@@ -2491,17 +2507,21 @@ class Unpivot(google.protobuf.message.Message):
input: global___Relation | None = ...,
ids: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
| None = ...,
- values: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression]
- | None = ...,
+ values: global___Unpivot.Values | None = ...,
variable_column_name: builtins.str = ...,
value_column_name: builtins.str = ...,
) -> None: ...
def HasField(
- self, field_name: typing_extensions.Literal["input", b"input"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_values", b"_values", "input", b"input", "values", b"values"
+ ],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
+ "_values",
+ b"_values",
"ids",
b"ids",
"input",
@@ -2514,6 +2534,9 @@ class Unpivot(google.protobuf.message.Message):
b"variable_column_name",
],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_values", b"_values"]
+ ) -> typing_extensions.Literal["values"] | None: ...
global___Unpivot = Unpivot
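In the stubs above, `_values` is protoc's synthetic oneof for the proto3 `optional` field, so `WhichOneof` offers an alternative presence probe (sketch):

    from pyspark.sql.connect.proto import relations_pb2

    rel = relations_pb2.Unpivot()
    assert rel.WhichOneof("_values") is None    # not set
    rel.values.SetInParent()
    assert rel.WhichOneof("_values") == "values"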
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 980f61eb4b1..1892e64f8f9 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -192,9 +192,12 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
)
self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
- self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values))
+ self.assertEqual(plan.root.unpivot.HasField("values"), True)
+ self.assertTrue(
+ all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values.values)
+ )
self.assertEqual(
- plan.root.unpivot.values[0].unresolved_attribute.unparsed_identifier, "name"
+ plan.root.unpivot.values.values[0].unresolved_attribute.unparsed_identifier, "name"
)
self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
self.assertEqual(plan.root.unpivot.value_column_name, "value")
@@ -207,7 +210,7 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
self.assertTrue(len(plan.root.unpivot.ids) == 1)
self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
- self.assertTrue(len(plan.root.unpivot.values) == 0)
+ self.assertEqual(plan.root.unpivot.HasField("values"), False)
self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
self.assertEqual(plan.root.unpivot.value_column_name, "value")
@@ -221,9 +224,12 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
)
self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
- self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values))
+ self.assertEqual(plan.root.unpivot.HasField("values"), True)
+ self.assertTrue(
+ all(isinstance(c, proto.Expression) for c in plan.root.unpivot.values.values)
+ )
self.assertEqual(
- plan.root.unpivot.values[0].unresolved_attribute.unparsed_identifier, "name"
+ plan.root.unpivot.values.values[0].unresolved_attribute.unparsed_identifier, "name"
)
self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
self.assertEqual(plan.root.unpivot.value_column_name, "value")
@@ -236,7 +242,8 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
self.assertTrue(len(plan.root.unpivot.ids) == 1)
self.assertTrue(all(isinstance(c, proto.Expression) for c in plan.root.unpivot.ids))
self.assertEqual(plan.root.unpivot.ids[0].unresolved_attribute.unparsed_identifier, "id")
- self.assertTrue(len(plan.root.unpivot.values) == 0)
+ self.assertEqual(plan.root.unpivot.HasField("values"), True)
+ self.assertTrue(len(plan.root.unpivot.values.values) == 0)
self.assertEqual(plan.root.unpivot.variable_column_name, "variable")
self.assertEqual(plan.root.unpivot.value_column_name, "value")
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 8413dbaf06d..07cae0fb27d 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -142,11 +142,6 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
def test_to_pandas_with_duplicated_column_names(self):
super().test_to_pandas_with_duplicated_column_names()
- # TODO(SPARK-41963): Different exception message in DataFrame.unpivot
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_unpivot_negative(self):
- super().test_unpivot_negative()
-
if __name__ == "__main__":
import unittest
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org