You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/25 18:14:29 UTC

[spark] branch branch-3.4 updated: [SPARK-42570][CONNECT][PYTHON] Fix DataFrameReader to use the default source

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 8d2a1c496db [SPARK-42570][CONNECT][PYTHON] Fix DataFrameReader to use the default source
8d2a1c496db is described below

commit 8d2a1c496dbb05e5f390a4fb2d7481ac7f6a868f
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Sat Feb 25 14:14:01 2023 -0400

    [SPARK-42570][CONNECT][PYTHON] Fix DataFrameReader to use the default source
    
    ### What changes were proposed in this pull request?
    
    Fixes the Spark Connect `DataFrameReader` to fall back to the default data source (SQL conf `spark.sql.sources.default`) when no format is specified.
    
    ### Why are the changes needed?
    
    ```py
    spark.read.load(path)
    ```
    
    should work and use the default source without specifying the format.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. `format` no longer needs to be specified; when it is omitted, the data source configured by `spark.sql.sources.default` is used.
    
    ### How was this patch tested?
    
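    Re-enabled the previously skipped parity tests `test_save_and_load` and `test_save_and_load_builder` in `test_parity_readwriter.py`.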
    
    Closes #40166 from ueshin/issues/SPARK-42570/reader.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
    (cherry picked from commit ad35f35f12f715c276d216d621be583a6a44111a)
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../main/protobuf/spark/connect/relations.proto    |   6 +-
 .../sql/connect/planner/SparkConnectPlanner.scala  |   7 +-
 .../connect/planner/SparkConnectPlannerSuite.scala |  12 --
 python/pyspark/sql/connect/plan.py                 |   8 +-
 python/pyspark/sql/connect/proto/relations_pb2.py  | 186 ++++++++++-----------
 python/pyspark/sql/connect/proto/relations_pb2.pyi |  26 ++-
 python/pyspark/sql/connect/readwriter.py           |   2 +-
 .../sql/tests/connect/test_parity_readwriter.py    |  10 +-
 python/pyspark/sql/tests/test_readwriter.py        | 126 +++++++-------
 9 files changed, 193 insertions(+), 190 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 4d96b6b0c7e..2221b4e3982 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -122,8 +122,10 @@ message Read {
   }
 
   message DataSource {
-    // (Required) Supported formats include: parquet, orc, text, json, parquet, csv, avro.
-    string format = 1;
+    // (Optional) Supported formats include: parquet, orc, text, json, parquet, csv, avro.
+    //
+    // If not set, the value from SQL conf 'spark.sql.sources.default' will be used.
+    optional string format = 1;
 
     // (Optional) If not set, Spark will infer the schema.
     //
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index cc43c1cace3..887379ab80d 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -667,12 +667,11 @@ class SparkConnectPlanner(val session: SparkSession) {
         UnresolvedRelation(multipartIdentifier)
 
       case proto.Read.ReadTypeCase.DATA_SOURCE =>
-        if (rel.getDataSource.getFormat == "") {
-          throw InvalidPlanInput("DataSource requires a format")
-        }
         val localMap = CaseInsensitiveMap[String](rel.getDataSource.getOptionsMap.asScala.toMap)
         val reader = session.read
-        reader.format(rel.getDataSource.getFormat)
+        if (rel.getDataSource.hasFormat) {
+          reader.format(rel.getDataSource.getFormat)
+        }
         localMap.foreach { case (key, value) => reader.option(key, value) }
         if (rel.getDataSource.hasSchema && rel.getDataSource.getSchema.nonEmpty) {
 
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 3e4a0f94ea2..83056c27729 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -332,18 +332,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
     assert(res.nodeName == "Aggregate")
   }
 
-  test("Invalid DataSource") {
-    val dataSource = proto.Read.DataSource.newBuilder()
-
-    val e = intercept[InvalidPlanInput](
-      transform(
-        proto.Relation
-          .newBuilder()
-          .setRead(proto.Read.newBuilder().setDataSource(dataSource))
-          .build()))
-    assert(e.getMessage.contains("DataSource requires a format"))
-  }
-
   test("Test invalid deduplicate") {
     val deduplicate = proto.Deduplicate
       .newBuilder()
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index badbb9871ed..857cca64c6f 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -255,15 +255,14 @@ class DataSource(LogicalPlan):
 
     def __init__(
         self,
-        format: str,
+        format: Optional[str] = None,
         schema: Optional[str] = None,
         options: Optional[Mapping[str, str]] = None,
         paths: Optional[List[str]] = None,
     ) -> None:
         super().__init__(None)
 
-        assert isinstance(format, str) and format != ""
-
+        assert format is None or isinstance(format, str)
         assert schema is None or isinstance(schema, str)
 
         if options is not None:
@@ -282,7 +281,8 @@ class DataSource(LogicalPlan):
 
     def plan(self, session: "SparkConnectClient") -> proto.Relation:
         plan = self._create_proto_relation()
-        plan.read.data_source.format = self._format
+        if self._format is not None:
+            plan.read.data_source.format = self._format
         if self._schema is not None:
             plan.read.data_source.schema = self._schema
         if self._options is not None and len(self._options) > 0:
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 3afdf61e681..c6d9616e44c 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
 )
 
 
@@ -657,99 +657,99 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _SQL_ARGSENTRY._serialized_start = 2704
     _SQL_ARGSENTRY._serialized_end = 2759
     _READ._serialized_start = 2762
-    _READ._serialized_end = 3210
+    _READ._serialized_end = 3226
     _READ_NAMEDTABLE._serialized_start = 2904
     _READ_NAMEDTABLE._serialized_end = 2965
     _READ_DATASOURCE._serialized_start = 2968
-    _READ_DATASOURCE._serialized_end = 3197
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3128
-    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3186
-    _PROJECT._serialized_start = 3212
-    _PROJECT._serialized_end = 3329
-    _FILTER._serialized_start = 3331
-    _FILTER._serialized_end = 3443
-    _JOIN._serialized_start = 3446
-    _JOIN._serialized_end = 3917
-    _JOIN_JOINTYPE._serialized_start = 3709
-    _JOIN_JOINTYPE._serialized_end = 3917
-    _SETOPERATION._serialized_start = 3920
-    _SETOPERATION._serialized_end = 4399
-    _SETOPERATION_SETOPTYPE._serialized_start = 4236
-    _SETOPERATION_SETOPTYPE._serialized_end = 4350
-    _LIMIT._serialized_start = 4401
-    _LIMIT._serialized_end = 4477
-    _OFFSET._serialized_start = 4479
-    _OFFSET._serialized_end = 4558
-    _TAIL._serialized_start = 4560
-    _TAIL._serialized_end = 4635
-    _AGGREGATE._serialized_start = 4638
-    _AGGREGATE._serialized_end = 5220
-    _AGGREGATE_PIVOT._serialized_start = 4977
-    _AGGREGATE_PIVOT._serialized_end = 5088
-    _AGGREGATE_GROUPTYPE._serialized_start = 5091
-    _AGGREGATE_GROUPTYPE._serialized_end = 5220
-    _SORT._serialized_start = 5223
-    _SORT._serialized_end = 5383
-    _DROP._serialized_start = 5385
-    _DROP._serialized_end = 5485
-    _DEDUPLICATE._serialized_start = 5488
-    _DEDUPLICATE._serialized_end = 5659
-    _LOCALRELATION._serialized_start = 5661
-    _LOCALRELATION._serialized_end = 5750
-    _SAMPLE._serialized_start = 5753
-    _SAMPLE._serialized_end = 6026
-    _RANGE._serialized_start = 6029
-    _RANGE._serialized_end = 6174
-    _SUBQUERYALIAS._serialized_start = 6176
-    _SUBQUERYALIAS._serialized_end = 6290
-    _REPARTITION._serialized_start = 6293
-    _REPARTITION._serialized_end = 6435
-    _SHOWSTRING._serialized_start = 6438
-    _SHOWSTRING._serialized_end = 6580
-    _STATSUMMARY._serialized_start = 6582
-    _STATSUMMARY._serialized_end = 6674
-    _STATDESCRIBE._serialized_start = 6676
-    _STATDESCRIBE._serialized_end = 6757
-    _STATCROSSTAB._serialized_start = 6759
-    _STATCROSSTAB._serialized_end = 6860
-    _STATCOV._serialized_start = 6862
-    _STATCOV._serialized_end = 6958
-    _STATCORR._serialized_start = 6961
-    _STATCORR._serialized_end = 7098
-    _STATAPPROXQUANTILE._serialized_start = 7101
-    _STATAPPROXQUANTILE._serialized_end = 7265
-    _STATFREQITEMS._serialized_start = 7267
-    _STATFREQITEMS._serialized_end = 7392
-    _STATSAMPLEBY._serialized_start = 7395
-    _STATSAMPLEBY._serialized_end = 7704
-    _STATSAMPLEBY_FRACTION._serialized_start = 7596
-    _STATSAMPLEBY_FRACTION._serialized_end = 7695
-    _NAFILL._serialized_start = 7707
-    _NAFILL._serialized_end = 7841
-    _NADROP._serialized_start = 7844
-    _NADROP._serialized_end = 7978
-    _NAREPLACE._serialized_start = 7981
-    _NAREPLACE._serialized_end = 8277
-    _NAREPLACE_REPLACEMENT._serialized_start = 8136
-    _NAREPLACE_REPLACEMENT._serialized_end = 8277
-    _TODF._serialized_start = 8279
-    _TODF._serialized_end = 8367
-    _WITHCOLUMNSRENAMED._serialized_start = 8370
-    _WITHCOLUMNSRENAMED._serialized_end = 8609
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8542
-    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8609
-    _WITHCOLUMNS._serialized_start = 8611
-    _WITHCOLUMNS._serialized_end = 8730
-    _HINT._serialized_start = 8733
-    _HINT._serialized_end = 8865
-    _UNPIVOT._serialized_start = 8868
-    _UNPIVOT._serialized_end = 9195
-    _UNPIVOT_VALUES._serialized_start = 9125
-    _UNPIVOT_VALUES._serialized_end = 9184
-    _TOSCHEMA._serialized_start = 9197
-    _TOSCHEMA._serialized_end = 9303
-    _REPARTITIONBYEXPRESSION._serialized_start = 9306
-    _REPARTITIONBYEXPRESSION._serialized_end = 9509
-    _FRAMEMAP._serialized_start = 9511
-    _FRAMEMAP._serialized_end = 9636
+    _READ_DATASOURCE._serialized_end = 3213
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3133
+    _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3191
+    _PROJECT._serialized_start = 3228
+    _PROJECT._serialized_end = 3345
+    _FILTER._serialized_start = 3347
+    _FILTER._serialized_end = 3459
+    _JOIN._serialized_start = 3462
+    _JOIN._serialized_end = 3933
+    _JOIN_JOINTYPE._serialized_start = 3725
+    _JOIN_JOINTYPE._serialized_end = 3933
+    _SETOPERATION._serialized_start = 3936
+    _SETOPERATION._serialized_end = 4415
+    _SETOPERATION_SETOPTYPE._serialized_start = 4252
+    _SETOPERATION_SETOPTYPE._serialized_end = 4366
+    _LIMIT._serialized_start = 4417
+    _LIMIT._serialized_end = 4493
+    _OFFSET._serialized_start = 4495
+    _OFFSET._serialized_end = 4574
+    _TAIL._serialized_start = 4576
+    _TAIL._serialized_end = 4651
+    _AGGREGATE._serialized_start = 4654
+    _AGGREGATE._serialized_end = 5236
+    _AGGREGATE_PIVOT._serialized_start = 4993
+    _AGGREGATE_PIVOT._serialized_end = 5104
+    _AGGREGATE_GROUPTYPE._serialized_start = 5107
+    _AGGREGATE_GROUPTYPE._serialized_end = 5236
+    _SORT._serialized_start = 5239
+    _SORT._serialized_end = 5399
+    _DROP._serialized_start = 5401
+    _DROP._serialized_end = 5501
+    _DEDUPLICATE._serialized_start = 5504
+    _DEDUPLICATE._serialized_end = 5675
+    _LOCALRELATION._serialized_start = 5677
+    _LOCALRELATION._serialized_end = 5766
+    _SAMPLE._serialized_start = 5769
+    _SAMPLE._serialized_end = 6042
+    _RANGE._serialized_start = 6045
+    _RANGE._serialized_end = 6190
+    _SUBQUERYALIAS._serialized_start = 6192
+    _SUBQUERYALIAS._serialized_end = 6306
+    _REPARTITION._serialized_start = 6309
+    _REPARTITION._serialized_end = 6451
+    _SHOWSTRING._serialized_start = 6454
+    _SHOWSTRING._serialized_end = 6596
+    _STATSUMMARY._serialized_start = 6598
+    _STATSUMMARY._serialized_end = 6690
+    _STATDESCRIBE._serialized_start = 6692
+    _STATDESCRIBE._serialized_end = 6773
+    _STATCROSSTAB._serialized_start = 6775
+    _STATCROSSTAB._serialized_end = 6876
+    _STATCOV._serialized_start = 6878
+    _STATCOV._serialized_end = 6974
+    _STATCORR._serialized_start = 6977
+    _STATCORR._serialized_end = 7114
+    _STATAPPROXQUANTILE._serialized_start = 7117
+    _STATAPPROXQUANTILE._serialized_end = 7281
+    _STATFREQITEMS._serialized_start = 7283
+    _STATFREQITEMS._serialized_end = 7408
+    _STATSAMPLEBY._serialized_start = 7411
+    _STATSAMPLEBY._serialized_end = 7720
+    _STATSAMPLEBY_FRACTION._serialized_start = 7612
+    _STATSAMPLEBY_FRACTION._serialized_end = 7711
+    _NAFILL._serialized_start = 7723
+    _NAFILL._serialized_end = 7857
+    _NADROP._serialized_start = 7860
+    _NADROP._serialized_end = 7994
+    _NAREPLACE._serialized_start = 7997
+    _NAREPLACE._serialized_end = 8293
+    _NAREPLACE_REPLACEMENT._serialized_start = 8152
+    _NAREPLACE_REPLACEMENT._serialized_end = 8293
+    _TODF._serialized_start = 8295
+    _TODF._serialized_end = 8383
+    _WITHCOLUMNSRENAMED._serialized_start = 8386
+    _WITHCOLUMNSRENAMED._serialized_end = 8625
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8558
+    _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8625
+    _WITHCOLUMNS._serialized_start = 8627
+    _WITHCOLUMNS._serialized_end = 8746
+    _HINT._serialized_start = 8749
+    _HINT._serialized_end = 8881
+    _UNPIVOT._serialized_start = 8884
+    _UNPIVOT._serialized_end = 9211
+    _UNPIVOT_VALUES._serialized_start = 9141
+    _UNPIVOT_VALUES._serialized_end = 9200
+    _TOSCHEMA._serialized_start = 9213
+    _TOSCHEMA._serialized_end = 9319
+    _REPARTITIONBYEXPRESSION._serialized_start = 9322
+    _REPARTITIONBYEXPRESSION._serialized_end = 9525
+    _FRAMEMAP._serialized_start = 9527
+    _FRAMEMAP._serialized_end = 9652
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 3f3b9f4c5b0..27fd07a192e 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -602,7 +602,10 @@ class Read(google.protobuf.message.Message):
         OPTIONS_FIELD_NUMBER: builtins.int
         PATHS_FIELD_NUMBER: builtins.int
         format: builtins.str
-        """(Required) Supported formats include: parquet, orc, text, json, parquet, csv, avro."""
+        """(Optional) Supported formats include: parquet, orc, text, json, parquet, csv, avro.
+
+        If not set, the value from SQL conf 'spark.sql.sources.default' will be used.
+        """
         schema: builtins.str
         """(Optional) If not set, Spark will infer the schema.
 
@@ -624,17 +627,29 @@ class Read(google.protobuf.message.Message):
         def __init__(
             self,
             *,
-            format: builtins.str = ...,
+            format: builtins.str | None = ...,
             schema: builtins.str | None = ...,
             options: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
             paths: collections.abc.Iterable[builtins.str] | None = ...,
         ) -> None: ...
         def HasField(
-            self, field_name: typing_extensions.Literal["_schema", b"_schema", "schema", b"schema"]
+            self,
+            field_name: typing_extensions.Literal[
+                "_format",
+                b"_format",
+                "_schema",
+                b"_schema",
+                "format",
+                b"format",
+                "schema",
+                b"schema",
+            ],
         ) -> builtins.bool: ...
         def ClearField(
             self,
             field_name: typing_extensions.Literal[
+                "_format",
+                b"_format",
                 "_schema",
                 b"_schema",
                 "format",
@@ -647,6 +662,11 @@ class Read(google.protobuf.message.Message):
                 b"schema",
             ],
         ) -> None: ...
+        @typing.overload
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_format", b"_format"]
+        ) -> typing_extensions.Literal["format"] | None: ...
+        @typing.overload
         def WhichOneof(
             self, oneof_group: typing_extensions.Literal["_schema", b"_schema"]
         ) -> typing_extensions.Literal["schema"] | None: ...
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 292e58b3552..9c9c79cb6eb 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -63,7 +63,7 @@ class DataFrameReader(OptionUtils):
 
     def __init__(self, client: "SparkSession"):
         self._client = client
-        self._format = ""
+        self._format: Optional[str] = None
         self._schema = ""
         self._options: Dict[str, str] = {}
 
diff --git a/python/pyspark/sql/tests/connect/test_parity_readwriter.py b/python/pyspark/sql/tests/connect/test_parity_readwriter.py
index bf77043ef38..2fa3f79a92f 100644
--- a/python/pyspark/sql/tests/connect/test_parity_readwriter.py
+++ b/python/pyspark/sql/tests/connect/test_parity_readwriter.py
@@ -22,15 +22,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
 class ReadwriterParityTests(ReadwriterTestsMixin, ReusedConnectTestCase):
-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_save_and_load(self):
-        super().test_save_and_load()
-
-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_save_and_load_builder(self):
-        super().test_save_and_load_builder()
+    pass
 
 
 class ReadwriterV2ParityTests(ReadwriterV2TestsMixin, ReusedConnectTestCase):
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 7f9b5e61051..21c66284ace 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -31,75 +31,77 @@ class ReadwriterTestsMixin:
         df = self.df
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
-        df.write.json(tmpPath)
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        schema = StructType([StructField("value", StringType(), True)])
-        actual = self.spark.read.json(tmpPath, schema)
-        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-
-        df.write.json(tmpPath, "overwrite")
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        df.write.save(
-            format="json",
-            mode="overwrite",
-            path=tmpPath,
-            noUse="this options will not be used in save.",
-        )
-        actual = self.spark.read.load(
-            format="json", path=tmpPath, noUse="this options will not be used in load."
-        )
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        defaultDataSourceName = self.spark.conf.get(
-            "spark.sql.sources.default", "org.apache.spark.sql.parquet"
-        )
-        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        actual = self.spark.read.load(path=tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+        try:
+            df.write.json(tmpPath)
+            actual = self.spark.read.json(tmpPath)
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+            schema = StructType([StructField("value", StringType(), True)])
+            actual = self.spark.read.json(tmpPath, schema)
+            self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+
+            df.write.json(tmpPath, "overwrite")
+            actual = self.spark.read.json(tmpPath)
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+            df.write.save(
+                format="json",
+                mode="overwrite",
+                path=tmpPath,
+                noUse="this options will not be used in save.",
+            )
+            actual = self.spark.read.load(
+                format="json", path=tmpPath, noUse="this options will not be used in load."
+            )
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
-        csvpath = os.path.join(tempfile.mkdtemp(), "data")
-        df.write.option("quote", None).format("csv").save(csvpath)
+            try:
+                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
+                actual = self.spark.read.load(path=tmpPath)
+                self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+            finally:
+                self.spark.sql("RESET spark.sql.sources.default").collect()
 
-        shutil.rmtree(tmpPath)
+            csvpath = os.path.join(tempfile.mkdtemp(), "data")
+            df.write.option("quote", None).format("csv").save(csvpath)
+        finally:
+            shutil.rmtree(tmpPath)
 
     def test_save_and_load_builder(self):
         df = self.df
         tmpPath = tempfile.mkdtemp()
         shutil.rmtree(tmpPath)
-        df.write.json(tmpPath)
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        schema = StructType([StructField("value", StringType(), True)])
-        actual = self.spark.read.json(tmpPath, schema)
-        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-
-        df.write.mode("overwrite").json(tmpPath)
-        actual = self.spark.read.json(tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        df.write.mode("overwrite").options(noUse="this options will not be used in save.").option(
-            "noUse", "this option will not be used in save."
-        ).format("json").save(path=tmpPath)
-        actual = self.spark.read.format("json").load(
-            path=tmpPath, noUse="this options will not be used in load."
-        )
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
-        defaultDataSourceName = self.spark.conf.get(
-            "spark.sql.sources.default", "org.apache.spark.sql.parquet"
-        )
-        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
-        actual = self.spark.read.load(path=tmpPath)
-        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
-
-        shutil.rmtree(tmpPath)
+        try:
+            df.write.json(tmpPath)
+            actual = self.spark.read.json(tmpPath)
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+            schema = StructType([StructField("value", StringType(), True)])
+            actual = self.spark.read.json(tmpPath, schema)
+            self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+
+            df.write.mode("overwrite").json(tmpPath)
+            actual = self.spark.read.json(tmpPath)
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+            df.write.mode("overwrite").options(
+                noUse="this options will not be used in save."
+            ).option("noUse", "this option will not be used in save.").format("json").save(
+                path=tmpPath
+            )
+            actual = self.spark.read.format("json").load(
+                path=tmpPath, noUse="this options will not be used in load."
+            )
+            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+            try:
+                self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
+                actual = self.spark.read.load(path=tmpPath)
+                self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+            finally:
+                self.spark.sql("RESET spark.sql.sources.default").collect()
+        finally:
+            shutil.rmtree(tmpPath)
 
     def test_bucketed_write(self):
         data = [


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org