Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/22 20:53:22 UTC

[spark] branch master updated: [SPARK-42522][CONNECT] Fix DataFrameWriterV2 to find the default source

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dbe23c8e88d [SPARK-42522][CONNECT] Fix DataFrameWriterV2 to find the default source
dbe23c8e88d is described below

commit dbe23c8e88d1a2968ae1c17ec9ee3029ef7a348a
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Wed Feb 22 16:53:06 2023 -0400

    [SPARK-42522][CONNECT] Fix DataFrameWriterV2 to find the default source
    
    ### What changes were proposed in this pull request?
    
    Fixes `DataFrameWriterV2` to find the default source.
    
    ### Why are the changes needed?
    
    Currently, `DataFrameWriterV2` in Spark Connect fails with a confusing error when no provider is specified:
    
    For example:
    
    ```py
    df.writeTo("test_table").create()
    ```
    
    ```
    pyspark.errors.exceptions.connect.SparkConnectGrpcException: (org.apache.spark.SparkClassNotFoundException) [DATA_SOURCE_NOT_FOUND] Failed to find the data source: . Please find packages at `https://spark.apache.org/third-party-projects.html`.
    ```
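
    The empty data source name in that error points at proto3 presence semantics: a plain `string provider` field is never null and reads back as `""` when unset, so the server-side `getProvider != null` check always passed and Spark tried to load a source named `""`. A minimal sketch against the generated Python bindings (illustrative only, assuming the `optional` field introduced by this patch):

    ```py
    # Sketch: proto3 scalar fields have no presence by default, so an unset
    # `provider` is indistinguishable from an empty string. Declaring the field
    # `optional` (as this patch does) enables an explicit presence check.
    from pyspark.sql.connect.proto import commands_pb2

    op = commands_pb2.WriteOperationV2()
    print(op.provider)              # "" -- the proto3 default, not None
    print(op.HasField("provider"))  # False until the client actually sets it
    ```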
    
    ### Does this PR introduce _any_ user-facing change?
    
    Users will be able to use `DataFrameWriterV2` without specifying a provider, just as in classic PySpark.
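
    For illustration, a minimal sketch of the call patterns this aligns with (assuming a Spark Connect session `spark`; the table names are hypothetical):

    ```py
    # With this fix, omitting .using(...) defers to the server's default source
    # resolution instead of failing with DATA_SOURCE_NOT_FOUND, and any errors
    # (e.g. "Hive support is required") match classic PySpark.
    df = spark.range(100)
    df.writeTo("demo_table_parquet").using("parquet").create()  # explicit provider
    df.writeTo("demo_table").create()                           # default provider
    ```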
    
    ### How was this patch tested?
    
    Added some tests.
    
    Closes #40109 from ueshin/issues/SPARK-42522/writer_v2.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../src/main/protobuf/spark/connect/commands.proto       |  2 +-
 .../spark/sql/connect/planner/SparkConnectPlanner.scala  |  6 +++---
 python/pyspark/sql/connect/proto/commands_pb2.py         | 12 ++++++------
 python/pyspark/sql/connect/proto/commands_pb2.pyi        | 16 ++++++++++++++--
 python/pyspark/sql/tests/test_readwriter.py              | 12 ++++++++++++
 5 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 7567b0e3d7c..1f2f473a050 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -128,7 +128,7 @@ message WriteOperationV2 {
 
   // (Optional) A provider for the underlying output data source. Spark's default catalog supports
   // "parquet", "json", etc.
-  string provider = 3;
+  optional string provider = 3;
 
   // (Optional) List of columns for partitioning for output table created by `create`,
   // `createOrReplace`, or `replace`
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a14d3632d28..268bf02fad9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1614,7 +1614,7 @@ class SparkConnectPlanner(val session: SparkSession) {
 
     writeOperation.getMode match {
       case proto.WriteOperationV2.Mode.MODE_CREATE =>
-        if (writeOperation.getProvider != null) {
+        if (writeOperation.hasProvider) {
           w.using(writeOperation.getProvider).create()
         } else {
           w.create()
@@ -1626,13 +1626,13 @@ class SparkConnectPlanner(val session: SparkSession) {
       case proto.WriteOperationV2.Mode.MODE_APPEND =>
         w.append()
       case proto.WriteOperationV2.Mode.MODE_REPLACE =>
-        if (writeOperation.getProvider != null) {
+        if (writeOperation.hasProvider) {
           w.using(writeOperation.getProvider).replace()
         } else {
           w.replace()
         }
       case proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE =>
-        if (writeOperation.getProvider != null) {
+        if (writeOperation.hasProvider) {
           w.using(writeOperation.getProvider).createOrReplace()
         } else {
           w.createOrReplace()
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index faa7dd65e2e..c8ade1ea81b 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
+    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
 )
 
 
@@ -177,11 +177,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _WRITEOPERATION_SAVEMODE._serialized_start = 1639
     _WRITEOPERATION_SAVEMODE._serialized_end = 1776
     _WRITEOPERATIONV2._serialized_start = 1803
-    _WRITEOPERATIONV2._serialized_end = 2598
+    _WRITEOPERATIONV2._serialized_end = 2616
     _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1224
     _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1282
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2370
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2436
-    _WRITEOPERATIONV2_MODE._serialized_start = 2439
-    _WRITEOPERATIONV2_MODE._serialized_end = 2598
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2375
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2441
+    _WRITEOPERATIONV2_MODE._serialized_start = 2444
+    _WRITEOPERATIONV2_MODE._serialized_end = 2603
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index c102624ca44..fb767ead329 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -506,7 +506,7 @@ class WriteOperationV2(google.protobuf.message.Message):
         *,
         input: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
         table_name: builtins.str = ...,
-        provider: builtins.str = ...,
+        provider: builtins.str | None = ...,
         partitioning_columns: collections.abc.Iterable[
             pyspark.sql.connect.proto.expressions_pb2.Expression
         ]
@@ -519,12 +519,21 @@ class WriteOperationV2(google.protobuf.message.Message):
     def HasField(
         self,
         field_name: typing_extensions.Literal[
-            "input", b"input", "overwrite_condition", b"overwrite_condition"
+            "_provider",
+            b"_provider",
+            "input",
+            b"input",
+            "overwrite_condition",
+            b"overwrite_condition",
+            "provider",
+            b"provider",
         ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "_provider",
+            b"_provider",
             "input",
             b"input",
             "mode",
@@ -543,5 +552,8 @@ class WriteOperationV2(google.protobuf.message.Message):
             b"table_properties",
         ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_provider", b"_provider"]
+    ) -> typing_extensions.Literal["provider"] | None: ...
 
 global___WriteOperationV2 = WriteOperationV2
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 9cd3e613667..7f9b5e61051 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -19,6 +19,7 @@ import os
 import shutil
 import tempfile
 
+from pyspark.errors import AnalysisException
 from pyspark.sql.functions import col
 from pyspark.sql.readwriter import DataFrameWriterV2
 from pyspark.sql.types import StructType, StructField, StringType
@@ -215,6 +216,17 @@ class ReadwriterV2TestsMixin:
         self.assertIsInstance(writer.partitionedBy(bucket(11, col("id"))), tpe)
         self.assertIsInstance(writer.partitionedBy(bucket(3, "id"), hours(col("ts"))), tpe)
 
+    def test_create(self):
+        df = self.df
+        with self.table("test_table"):
+            df.writeTo("test_table").using("parquet").create()
+            self.assertEqual(100, self.spark.sql("select * from test_table").count())
+
+    def test_create_without_provider(self):
+        df = self.df
+        with self.assertRaisesRegex(AnalysisException, "Hive support is required"):
+            df.writeTo("test_table").create()
+
 
 class ReadwriterTests(ReadwriterTestsMixin, ReusedSQLTestCase):
     pass

