Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/22 20:53:22 UTC
[spark] branch master updated: [SPARK-42522][CONNECT] Fix DataFrameWriterV2 to find the default source
This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new dbe23c8e88d [SPARK-42522][CONNECT] Fix DataFrameWriterV2 to find the default source
dbe23c8e88d is described below
commit dbe23c8e88d1a2968ae1c17ec9ee3029ef7a348a
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Wed Feb 22 16:53:06 2023 -0400
[SPARK-42522][CONNECT] Fix DataFrameWriterV2 to find the default source
### What changes were proposed in this pull request?
Fixes `DataFrameWriterV2` to find the default source.
### Why are the changes needed?
Currently `DataFrameWriterV2` in Spark Connect does not work when no provider is specified, failing with a confusing error.
For example:
```py
df.writeTo("test_table").create()
```
```
pyspark.errors.exceptions.connect.SparkConnectGrpcException: (org.apache.spark.SparkClassNotFoundException) [DATA_SOURCE_NOT_FOUND] Failed to find the data source: . Please find packages at `https://spark.apache.org/third-party-projects.html`.
```
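Note the empty data source name in the error (`Failed to find the data source: .`): the unset `provider` field was read back as the empty string and passed along as if it were a real source name. After the fix, an omitted provider is genuinely absent. A minimal sketch of the fixed behavior, assuming a Spark Connect session `spark` (table names are illustrative):
```py
df = spark.range(100)

# Explicit provider, as before:
df.writeTo("test_table").using("parquet").create()

# Provider omitted: now handled the same way as in classic PySpark,
# instead of failing with DATA_SOURCE_NOT_FOUND for the empty name "".
df.writeTo("test_table2").create()
```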
### Does this PR introduce _any_ user-facing change?
Users will be able to use `DataFrameWriterV2` without specifying a provider, just as in classic PySpark.
### How was this patch tested?
Added some tests.
Closes #40109 from ueshin/issues/SPARK-42522/writer_v2.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Herman van Hovell <he...@databricks.com>
---
.../src/main/protobuf/spark/connect/commands.proto | 2 +-
.../spark/sql/connect/planner/SparkConnectPlanner.scala | 6 +++---
python/pyspark/sql/connect/proto/commands_pb2.py | 12 ++++++------
python/pyspark/sql/connect/proto/commands_pb2.pyi | 16 ++++++++++++++--
python/pyspark/sql/tests/test_readwriter.py | 12 ++++++++++++
5 files changed, 36 insertions(+), 12 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 7567b0e3d7c..1f2f473a050 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -128,7 +128,7 @@ message WriteOperationV2 {
// (Optional) A provider for the underlying output data source. Spark's default catalog supports
// "parquet", "json", etc.
- string provider = 3;
+ optional string provider = 3;
// (Optional) List of columns for partitioning for output table created by `create`,
// `createOrReplace`, or `replace`
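For context on this proto change: proto3 string fields have no null state in the generated bindings, so an unset field reads back as `""`. Declaring the field `optional` adds explicit presence tracking, letting the server tell an omitted provider apart from an empty one. A minimal sketch against the generated Python bindings:
```py
from pyspark.sql.connect.proto import commands_pb2

op = commands_pb2.WriteOperationV2(table_name="test_table")
print(repr(op.provider))        # '' -- unset scalars read as their default
print(op.HasField("provider"))  # False -- presence is tracked separately

op.provider = "parquet"
print(op.HasField("provider"))  # True
```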
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a14d3632d28..268bf02fad9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1614,7 +1614,7 @@ class SparkConnectPlanner(val session: SparkSession) {
writeOperation.getMode match {
case proto.WriteOperationV2.Mode.MODE_CREATE =>
- if (writeOperation.getProvider != null) {
+ if (writeOperation.hasProvider) {
w.using(writeOperation.getProvider).create()
} else {
w.create()
@@ -1626,13 +1626,13 @@ class SparkConnectPlanner(val session: SparkSession) {
case proto.WriteOperationV2.Mode.MODE_APPEND =>
w.append()
case proto.WriteOperationV2.Mode.MODE_REPLACE =>
- if (writeOperation.getProvider != null) {
+ if (writeOperation.hasProvider) {
w.using(writeOperation.getProvider).replace()
} else {
w.replace()
}
case proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE =>
- if (writeOperation.getProvider != null) {
+ if (writeOperation.hasProvider) {
w.using(writeOperation.getProvider).createOrReplace()
} else {
w.createOrReplace()
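This hunk is the actual bug fix: the generated `getProvider` returns the empty string for an unset proto3 string, never `null`, so the old check always succeeded and the planner passed `""` through `using(...)`, producing the empty source name seen in the error above. `hasProvider`, available now that the field is `optional`, tests real presence. A hypothetical Python analogue of the dispatch (`op` and `w` are illustrative names, not part of the commit):
```py
def create_table(op, w):
    # Test presence explicitly; truthiness of op.provider would conflate
    # an omitted provider with provider="".
    if op.HasField("provider"):
        w.using(op.provider).create()
    else:
        w.create()
```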
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index faa7dd65e2e..c8ade1ea81b 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
+ b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
)
@@ -177,11 +177,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_WRITEOPERATION_SAVEMODE._serialized_start = 1639
_WRITEOPERATION_SAVEMODE._serialized_end = 1776
_WRITEOPERATIONV2._serialized_start = 1803
- _WRITEOPERATIONV2._serialized_end = 2598
+ _WRITEOPERATIONV2._serialized_end = 2616
_WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1224
_WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1282
- _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2370
- _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2436
- _WRITEOPERATIONV2_MODE._serialized_start = 2439
- _WRITEOPERATIONV2_MODE._serialized_end = 2598
+ _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2375
+ _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2441
+ _WRITEOPERATIONV2_MODE._serialized_start = 2444
+ _WRITEOPERATIONV2_MODE._serialized_end = 2603
# @@protoc_insertion_point(module_scope)
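The `commands_pb2.py` hunk above is mechanical: the file is regenerated by protoc, and the `_serialized_start`/`_serialized_end` values are byte offsets into the serialized file descriptor, which shift once the `optional` keyword grows the message definition.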
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index c102624ca44..fb767ead329 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -506,7 +506,7 @@ class WriteOperationV2(google.protobuf.message.Message):
*,
input: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
table_name: builtins.str = ...,
- provider: builtins.str = ...,
+ provider: builtins.str | None = ...,
partitioning_columns: collections.abc.Iterable[
pyspark.sql.connect.proto.expressions_pb2.Expression
]
@@ -519,12 +519,21 @@ class WriteOperationV2(google.protobuf.message.Message):
def HasField(
self,
field_name: typing_extensions.Literal[
- "input", b"input", "overwrite_condition", b"overwrite_condition"
+ "_provider",
+ b"_provider",
+ "input",
+ b"input",
+ "overwrite_condition",
+ b"overwrite_condition",
+ "provider",
+ b"provider",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
+ "_provider",
+ b"_provider",
"input",
b"input",
"mode",
@@ -543,5 +552,8 @@ class WriteOperationV2(google.protobuf.message.Message):
b"table_properties",
],
) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_provider", b"_provider"]
+ ) -> typing_extensions.Literal["provider"] | None: ...
global___WriteOperationV2 = WriteOperationV2
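In the generated Python API, a proto3 `optional` field is represented as a synthetic one-field oneof named `_provider`, which is why the type stubs gain `"_provider"` entries in `HasField`/`ClearField` and a `WhichOneof` overload. A small sketch:
```py
from pyspark.sql.connect.proto import commands_pb2

op = commands_pb2.WriteOperationV2(provider="parquet")
print(op.WhichOneof("_provider"))  # 'provider' -- the synthetic oneof is set

op.ClearField("provider")
print(op.WhichOneof("_provider"))  # None
```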
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 9cd3e613667..7f9b5e61051 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -19,6 +19,7 @@ import os
import shutil
import tempfile
+from pyspark.errors import AnalysisException
from pyspark.sql.functions import col
from pyspark.sql.readwriter import DataFrameWriterV2
from pyspark.sql.types import StructType, StructField, StringType
@@ -215,6 +216,17 @@ class ReadwriterV2TestsMixin:
self.assertIsInstance(writer.partitionedBy(bucket(11, col("id"))), tpe)
self.assertIsInstance(writer.partitionedBy(bucket(3, "id"), hours(col("ts"))), tpe)
+ def test_create(self):
+ df = self.df
+ with self.table("test_table"):
+ df.writeTo("test_table").using("parquet").create()
+ self.assertEqual(100, self.spark.sql("select * from test_table").count())
+
+ def test_create_without_provider(self):
+ df = self.df
+ with self.assertRaisesRegex(AnalysisException, "Hive support is required"):
+ df.writeTo("test_table").create()
+
class ReadwriterTests(ReadwriterTestsMixin, ReusedSQLTestCase):
pass
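The two added tests cover both paths: `test_create` exercises an explicit provider end to end, while `test_create_without_provider` pins the no-provider behavior to the same `AnalysisException` that classic PySpark raises in a non-Hive test session, rather than the old gRPC `DATA_SOURCE_NOT_FOUND` failure. Assuming the standard Spark dev tooling, one way to run just this suite from a checkout is `python/run-tests --testnames 'pyspark.sql.tests.test_readwriter'`.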