Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/25 18:14:15 UTC
[spark] branch master updated: [SPARK-42570][CONNECT][PYTHON] Fix DataFrameReader to use the default source
This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ad35f35f12f [SPARK-42570][CONNECT][PYTHON] Fix DataFrameReader to use the default source
ad35f35f12f is described below
commit ad35f35f12f715c276d216d621be583a6a44111a
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Sat Feb 25 14:14:01 2023 -0400
[SPARK-42570][CONNECT][PYTHON] Fix DataFrameReader to use the default source
### What changes were proposed in this pull request?
Fixes `DataFrameReader` to use the default source.
### Why are the changes needed?
```py
spark.read.load(path)
```
should work and use the default source without specifying the format.
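With the fix, for example (path illustrative), a bare load resolves the source from the SQL conf `spark.sql.sources.default`:
```py
df = spark.read.load("/tmp/events")  # no .format(): the default source applies
```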
### Does this PR introduce _any_ user-facing change?
Yes. The `format` no longer needs to be specified for `DataFrameReader` in Spark Connect; when omitted, the default source configured by `spark.sql.sources.default` is used.
### How was this patch tested?
Enabled related tests.
Closes #40166 from ueshin/issues/SPARK-42570/reader.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Herman van Hovell <he...@databricks.com>
---
.../main/protobuf/spark/connect/relations.proto | 6 +-
.../sql/connect/planner/SparkConnectPlanner.scala | 7 +-
.../connect/planner/SparkConnectPlannerSuite.scala | 12 --
python/pyspark/sql/connect/plan.py | 8 +-
python/pyspark/sql/connect/proto/relations_pb2.py | 186 ++++++++++-----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 26 ++-
python/pyspark/sql/connect/readwriter.py | 2 +-
.../sql/tests/connect/test_parity_readwriter.py | 10 +-
python/pyspark/sql/tests/test_readwriter.py | 126 +++++++-------
9 files changed, 193 insertions(+), 190 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 4d96b6b0c7e..2221b4e3982 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -122,8 +122,10 @@ message Read {
}
message DataSource {
- // (Required) Supported formats include: parquet, orc, text, json, parquet, csv, avro.
- string format = 1;
+ // (Optional) Supported formats include: parquet, orc, text, json, parquet, csv, avro.
+ //
+ // If not set, the value from SQL conf 'spark.sql.sources.default' will be used.
+ optional string format = 1;
// (Optional) If not set, Spark will infer the schema.
//
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index cc43c1cace3..887379ab80d 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -667,12 +667,11 @@ class SparkConnectPlanner(val session: SparkSession) {
UnresolvedRelation(multipartIdentifier)
case proto.Read.ReadTypeCase.DATA_SOURCE =>
- if (rel.getDataSource.getFormat == "") {
- throw InvalidPlanInput("DataSource requires a format")
- }
val localMap = CaseInsensitiveMap[String](rel.getDataSource.getOptionsMap.asScala.toMap)
val reader = session.read
- reader.format(rel.getDataSource.getFormat)
+ if (rel.getDataSource.hasFormat) {
+ reader.format(rel.getDataSource.getFormat)
+ }
localMap.foreach { case (key, value) => reader.option(key, value) }
if (rel.getDataSource.hasSchema && rel.getDataSource.getSchema.nonEmpty) {
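With the `hasFormat` guard above, a plain save/load round-trip works end to end on the server default (parquet out of the box); paths are illustrative:
```py
df = spark.range(3)
df.write.mode("overwrite").save("/tmp/spark-42570-demo")  # written with the default source
same = spark.read.load("/tmp/spark-42570-demo")           # read back without .format()
assert sorted(same.collect()) == sorted(df.collect())
```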
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 3e4a0f94ea2..83056c27729 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -332,18 +332,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
assert(res.nodeName == "Aggregate")
}
- test("Invalid DataSource") {
- val dataSource = proto.Read.DataSource.newBuilder()
-
- val e = intercept[InvalidPlanInput](
- transform(
- proto.Relation
- .newBuilder()
- .setRead(proto.Read.newBuilder().setDataSource(dataSource))
- .build()))
- assert(e.getMessage.contains("DataSource requires a format"))
- }
-
test("Test invalid deduplicate") {
val deduplicate = proto.Deduplicate
.newBuilder()
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index badbb9871ed..857cca64c6f 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -255,15 +255,14 @@ class DataSource(LogicalPlan):
def __init__(
self,
- format: str,
+ format: Optional[str] = None,
schema: Optional[str] = None,
options: Optional[Mapping[str, str]] = None,
paths: Optional[List[str]] = None,
) -> None:
super().__init__(None)
- assert isinstance(format, str) and format != ""
-
+ assert format is None or isinstance(format, str)
assert schema is None or isinstance(schema, str)
if options is not None:
@@ -282,7 +281,8 @@ class DataSource(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
- plan.read.data_source.format = self._format
+ if self._format is not None:
+ plan.read.data_source.format = self._format
if self._schema is not None:
plan.read.data_source.schema = self._schema
if self._options is not None and len(self._options) > 0:
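This is the usual "set only when present" pattern for proto3 optional fields; a minimal standalone sketch of what `plan()` now does:
```py
from pyspark.sql.connect.proto import relations_pb2

plan = relations_pb2.Relation()
fmt = None  # what DataFrameReader holds until .format() is called
if fmt is not None:
    plan.read.data_source.format = fmt
# Left unset, nothing goes on the wire, so the server-side hasFormat check is false.
```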
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 3afdf61e681..c6d9616e44c 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb1\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -657,99 +657,99 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_SQL_ARGSENTRY._serialized_start = 2704
_SQL_ARGSENTRY._serialized_end = 2759
_READ._serialized_start = 2762
- _READ._serialized_end = 3210
+ _READ._serialized_end = 3226
_READ_NAMEDTABLE._serialized_start = 2904
_READ_NAMEDTABLE._serialized_end = 2965
_READ_DATASOURCE._serialized_start = 2968
- _READ_DATASOURCE._serialized_end = 3197
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3128
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3186
- _PROJECT._serialized_start = 3212
- _PROJECT._serialized_end = 3329
- _FILTER._serialized_start = 3331
- _FILTER._serialized_end = 3443
- _JOIN._serialized_start = 3446
- _JOIN._serialized_end = 3917
- _JOIN_JOINTYPE._serialized_start = 3709
- _JOIN_JOINTYPE._serialized_end = 3917
- _SETOPERATION._serialized_start = 3920
- _SETOPERATION._serialized_end = 4399
- _SETOPERATION_SETOPTYPE._serialized_start = 4236
- _SETOPERATION_SETOPTYPE._serialized_end = 4350
- _LIMIT._serialized_start = 4401
- _LIMIT._serialized_end = 4477
- _OFFSET._serialized_start = 4479
- _OFFSET._serialized_end = 4558
- _TAIL._serialized_start = 4560
- _TAIL._serialized_end = 4635
- _AGGREGATE._serialized_start = 4638
- _AGGREGATE._serialized_end = 5220
- _AGGREGATE_PIVOT._serialized_start = 4977
- _AGGREGATE_PIVOT._serialized_end = 5088
- _AGGREGATE_GROUPTYPE._serialized_start = 5091
- _AGGREGATE_GROUPTYPE._serialized_end = 5220
- _SORT._serialized_start = 5223
- _SORT._serialized_end = 5383
- _DROP._serialized_start = 5385
- _DROP._serialized_end = 5485
- _DEDUPLICATE._serialized_start = 5488
- _DEDUPLICATE._serialized_end = 5659
- _LOCALRELATION._serialized_start = 5661
- _LOCALRELATION._serialized_end = 5750
- _SAMPLE._serialized_start = 5753
- _SAMPLE._serialized_end = 6026
- _RANGE._serialized_start = 6029
- _RANGE._serialized_end = 6174
- _SUBQUERYALIAS._serialized_start = 6176
- _SUBQUERYALIAS._serialized_end = 6290
- _REPARTITION._serialized_start = 6293
- _REPARTITION._serialized_end = 6435
- _SHOWSTRING._serialized_start = 6438
- _SHOWSTRING._serialized_end = 6580
- _STATSUMMARY._serialized_start = 6582
- _STATSUMMARY._serialized_end = 6674
- _STATDESCRIBE._serialized_start = 6676
- _STATDESCRIBE._serialized_end = 6757
- _STATCROSSTAB._serialized_start = 6759
- _STATCROSSTAB._serialized_end = 6860
- _STATCOV._serialized_start = 6862
- _STATCOV._serialized_end = 6958
- _STATCORR._serialized_start = 6961
- _STATCORR._serialized_end = 7098
- _STATAPPROXQUANTILE._serialized_start = 7101
- _STATAPPROXQUANTILE._serialized_end = 7265
- _STATFREQITEMS._serialized_start = 7267
- _STATFREQITEMS._serialized_end = 7392
- _STATSAMPLEBY._serialized_start = 7395
- _STATSAMPLEBY._serialized_end = 7704
- _STATSAMPLEBY_FRACTION._serialized_start = 7596
- _STATSAMPLEBY_FRACTION._serialized_end = 7695
- _NAFILL._serialized_start = 7707
- _NAFILL._serialized_end = 7841
- _NADROP._serialized_start = 7844
- _NADROP._serialized_end = 7978
- _NAREPLACE._serialized_start = 7981
- _NAREPLACE._serialized_end = 8277
- _NAREPLACE_REPLACEMENT._serialized_start = 8136
- _NAREPLACE_REPLACEMENT._serialized_end = 8277
- _TODF._serialized_start = 8279
- _TODF._serialized_end = 8367
- _WITHCOLUMNSRENAMED._serialized_start = 8370
- _WITHCOLUMNSRENAMED._serialized_end = 8609
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8542
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8609
- _WITHCOLUMNS._serialized_start = 8611
- _WITHCOLUMNS._serialized_end = 8730
- _HINT._serialized_start = 8733
- _HINT._serialized_end = 8865
- _UNPIVOT._serialized_start = 8868
- _UNPIVOT._serialized_end = 9195
- _UNPIVOT_VALUES._serialized_start = 9125
- _UNPIVOT_VALUES._serialized_end = 9184
- _TOSCHEMA._serialized_start = 9197
- _TOSCHEMA._serialized_end = 9303
- _REPARTITIONBYEXPRESSION._serialized_start = 9306
- _REPARTITIONBYEXPRESSION._serialized_end = 9509
- _FRAMEMAP._serialized_start = 9511
- _FRAMEMAP._serialized_end = 9636
+ _READ_DATASOURCE._serialized_end = 3213
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3133
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3191
+ _PROJECT._serialized_start = 3228
+ _PROJECT._serialized_end = 3345
+ _FILTER._serialized_start = 3347
+ _FILTER._serialized_end = 3459
+ _JOIN._serialized_start = 3462
+ _JOIN._serialized_end = 3933
+ _JOIN_JOINTYPE._serialized_start = 3725
+ _JOIN_JOINTYPE._serialized_end = 3933
+ _SETOPERATION._serialized_start = 3936
+ _SETOPERATION._serialized_end = 4415
+ _SETOPERATION_SETOPTYPE._serialized_start = 4252
+ _SETOPERATION_SETOPTYPE._serialized_end = 4366
+ _LIMIT._serialized_start = 4417
+ _LIMIT._serialized_end = 4493
+ _OFFSET._serialized_start = 4495
+ _OFFSET._serialized_end = 4574
+ _TAIL._serialized_start = 4576
+ _TAIL._serialized_end = 4651
+ _AGGREGATE._serialized_start = 4654
+ _AGGREGATE._serialized_end = 5236
+ _AGGREGATE_PIVOT._serialized_start = 4993
+ _AGGREGATE_PIVOT._serialized_end = 5104
+ _AGGREGATE_GROUPTYPE._serialized_start = 5107
+ _AGGREGATE_GROUPTYPE._serialized_end = 5236
+ _SORT._serialized_start = 5239
+ _SORT._serialized_end = 5399
+ _DROP._serialized_start = 5401
+ _DROP._serialized_end = 5501
+ _DEDUPLICATE._serialized_start = 5504
+ _DEDUPLICATE._serialized_end = 5675
+ _LOCALRELATION._serialized_start = 5677
+ _LOCALRELATION._serialized_end = 5766
+ _SAMPLE._serialized_start = 5769
+ _SAMPLE._serialized_end = 6042
+ _RANGE._serialized_start = 6045
+ _RANGE._serialized_end = 6190
+ _SUBQUERYALIAS._serialized_start = 6192
+ _SUBQUERYALIAS._serialized_end = 6306
+ _REPARTITION._serialized_start = 6309
+ _REPARTITION._serialized_end = 6451
+ _SHOWSTRING._serialized_start = 6454
+ _SHOWSTRING._serialized_end = 6596
+ _STATSUMMARY._serialized_start = 6598
+ _STATSUMMARY._serialized_end = 6690
+ _STATDESCRIBE._serialized_start = 6692
+ _STATDESCRIBE._serialized_end = 6773
+ _STATCROSSTAB._serialized_start = 6775
+ _STATCROSSTAB._serialized_end = 6876
+ _STATCOV._serialized_start = 6878
+ _STATCOV._serialized_end = 6974
+ _STATCORR._serialized_start = 6977
+ _STATCORR._serialized_end = 7114
+ _STATAPPROXQUANTILE._serialized_start = 7117
+ _STATAPPROXQUANTILE._serialized_end = 7281
+ _STATFREQITEMS._serialized_start = 7283
+ _STATFREQITEMS._serialized_end = 7408
+ _STATSAMPLEBY._serialized_start = 7411
+ _STATSAMPLEBY._serialized_end = 7720
+ _STATSAMPLEBY_FRACTION._serialized_start = 7612
+ _STATSAMPLEBY_FRACTION._serialized_end = 7711
+ _NAFILL._serialized_start = 7723
+ _NAFILL._serialized_end = 7857
+ _NADROP._serialized_start = 7860
+ _NADROP._serialized_end = 7994
+ _NAREPLACE._serialized_start = 7997
+ _NAREPLACE._serialized_end = 8293
+ _NAREPLACE_REPLACEMENT._serialized_start = 8152
+ _NAREPLACE_REPLACEMENT._serialized_end = 8293
+ _TODF._serialized_start = 8295
+ _TODF._serialized_end = 8383
+ _WITHCOLUMNSRENAMED._serialized_start = 8386
+ _WITHCOLUMNSRENAMED._serialized_end = 8625
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8558
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8625
+ _WITHCOLUMNS._serialized_start = 8627
+ _WITHCOLUMNS._serialized_end = 8746
+ _HINT._serialized_start = 8749
+ _HINT._serialized_end = 8881
+ _UNPIVOT._serialized_start = 8884
+ _UNPIVOT._serialized_end = 9211
+ _UNPIVOT_VALUES._serialized_start = 9141
+ _UNPIVOT_VALUES._serialized_end = 9200
+ _TOSCHEMA._serialized_start = 9213
+ _TOSCHEMA._serialized_end = 9319
+ _REPARTITIONBYEXPRESSION._serialized_start = 9322
+ _REPARTITIONBYEXPRESSION._serialized_end = 9525
+ _FRAMEMAP._serialized_start = 9527
+ _FRAMEMAP._serialized_end = 9652
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 3f3b9f4c5b0..27fd07a192e 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -602,7 +602,10 @@ class Read(google.protobuf.message.Message):
OPTIONS_FIELD_NUMBER: builtins.int
PATHS_FIELD_NUMBER: builtins.int
format: builtins.str
- """(Required) Supported formats include: parquet, orc, text, json, parquet, csv, avro."""
+ """(Optional) Supported formats include: parquet, orc, text, json, parquet, csv, avro.
+
+ If not set, the value from SQL conf 'spark.sql.sources.default' will be used.
+ """
schema: builtins.str
"""(Optional) If not set, Spark will infer the schema.
@@ -624,17 +627,29 @@ class Read(google.protobuf.message.Message):
def __init__(
self,
*,
- format: builtins.str = ...,
+ format: builtins.str | None = ...,
schema: builtins.str | None = ...,
options: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
paths: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(
- self, field_name: typing_extensions.Literal["_schema", b"_schema", "schema", b"schema"]
+ self,
+ field_name: typing_extensions.Literal[
+ "_format",
+ b"_format",
+ "_schema",
+ b"_schema",
+ "format",
+ b"format",
+ "schema",
+ b"schema",
+ ],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
+ "_format",
+ b"_format",
"_schema",
b"_schema",
"format",
@@ -647,6 +662,11 @@ class Read(google.protobuf.message.Message):
b"schema",
],
) -> None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_format", b"_format"]
+ ) -> typing_extensions.Literal["format"] | None: ...
+ @typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_schema", b"_schema"]
) -> typing_extensions.Literal["schema"] | None: ...
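The added `WhichOneof` overload mirrors the synthetic `_format` oneof; a quick sketch of probing it:
```py
from pyspark.sql.connect.proto import relations_pb2

ds = relations_pb2.Read.DataSource(format="csv")
assert ds.WhichOneof("_format") == "format"
ds.ClearField("format")
assert ds.WhichOneof("_format") is None
```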
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 292e58b3552..9c9c79cb6eb 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -63,7 +63,7 @@ class DataFrameReader(OptionUtils):
def __init__(self, client: "SparkSession"):
self._client = client
- self._format = ""
+ self._format: Optional[str] = None
self._schema = ""
self._options: Dict[str, str] = {}
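Both call styles keep working on the client; only the unset default changes (path illustrative):
```py
spark.read.format("json").load("/tmp/data")  # explicit format, unchanged
spark.read.load("/tmp/data")                 # format omitted: the default source is used
```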
diff --git a/python/pyspark/sql/tests/connect/test_parity_readwriter.py b/python/pyspark/sql/tests/connect/test_parity_readwriter.py
index bf77043ef38..2fa3f79a92f 100644
--- a/python/pyspark/sql/tests/connect/test_parity_readwriter.py
+++ b/python/pyspark/sql/tests/connect/test_parity_readwriter.py
@@ -22,15 +22,7 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
class ReadwriterParityTests(ReadwriterTestsMixin, ReusedConnectTestCase):
- # TODO(SPARK-41834): Implement SparkSession.conf
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_save_and_load(self):
- super().test_save_and_load()
-
- # TODO(SPARK-41834): Implement SparkSession.conf
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_save_and_load_builder(self):
- super().test_save_and_load_builder()
+ pass
class ReadwriterV2ParityTests(ReadwriterV2TestsMixin, ReusedConnectTestCase):
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 7f9b5e61051..21c66284ace 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -31,75 +31,77 @@ class ReadwriterTestsMixin:
df = self.df
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
- df.write.json(tmpPath)
- actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- schema = StructType([StructField("value", StringType(), True)])
- actual = self.spark.read.json(tmpPath, schema)
- self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-
- df.write.json(tmpPath, "overwrite")
- actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- df.write.save(
- format="json",
- mode="overwrite",
- path=tmpPath,
- noUse="this options will not be used in save.",
- )
- actual = self.spark.read.load(
- format="json", path=tmpPath, noUse="this options will not be used in load."
- )
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- defaultDataSourceName = self.spark.conf.get(
- "spark.sql.sources.default", "org.apache.spark.sql.parquet"
- )
- self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
- actual = self.spark.read.load(path=tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+ try:
+ df.write.json(tmpPath)
+ actual = self.spark.read.json(tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.spark.read.json(tmpPath, schema)
+ self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+
+ df.write.json(tmpPath, "overwrite")
+ actual = self.spark.read.json(tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ df.write.save(
+ format="json",
+ mode="overwrite",
+ path=tmpPath,
+ noUse="this options will not be used in save.",
+ )
+ actual = self.spark.read.load(
+ format="json", path=tmpPath, noUse="this options will not be used in load."
+ )
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- csvpath = os.path.join(tempfile.mkdtemp(), "data")
- df.write.option("quote", None).format("csv").save(csvpath)
+ try:
+ self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
+ actual = self.spark.read.load(path=tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ finally:
+ self.spark.sql("RESET spark.sql.sources.default").collect()
- shutil.rmtree(tmpPath)
+ csvpath = os.path.join(tempfile.mkdtemp(), "data")
+ df.write.option("quote", None).format("csv").save(csvpath)
+ finally:
+ shutil.rmtree(tmpPath)
def test_save_and_load_builder(self):
df = self.df
tmpPath = tempfile.mkdtemp()
shutil.rmtree(tmpPath)
- df.write.json(tmpPath)
- actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- schema = StructType([StructField("value", StringType(), True)])
- actual = self.spark.read.json(tmpPath, schema)
- self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
-
- df.write.mode("overwrite").json(tmpPath)
- actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- df.write.mode("overwrite").options(noUse="this options will not be used in save.").option(
- "noUse", "this option will not be used in save."
- ).format("json").save(path=tmpPath)
- actual = self.spark.read.format("json").load(
- path=tmpPath, noUse="this options will not be used in load."
- )
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
-
- defaultDataSourceName = self.spark.conf.get(
- "spark.sql.sources.default", "org.apache.spark.sql.parquet"
- )
- self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
- actual = self.spark.read.load(path=tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
- self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
-
- shutil.rmtree(tmpPath)
+ try:
+ df.write.json(tmpPath)
+ actual = self.spark.read.json(tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.spark.read.json(tmpPath, schema)
+ self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
+
+ df.write.mode("overwrite").json(tmpPath)
+ actual = self.spark.read.json(tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ df.write.mode("overwrite").options(
+ noUse="this options will not be used in save."
+ ).option("noUse", "this option will not be used in save.").format("json").save(
+ path=tmpPath
+ )
+ actual = self.spark.read.format("json").load(
+ path=tmpPath, noUse="this options will not be used in load."
+ )
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+
+ try:
+ self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
+ actual = self.spark.read.load(path=tmpPath)
+ self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ finally:
+ self.spark.sql("RESET spark.sql.sources.default").collect()
+ finally:
+ shutil.rmtree(tmpPath)
def test_bucketed_write(self):
data = [
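The enabled tests restore the conf with SET/RESET inside try/finally rather than reading and re-setting it, which also works on Connect where `SparkSession.conf` is not yet implemented (SPARK-41834); `tmpPath` as in the tests above:
```py
try:
    spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json").collect()
    actual = spark.read.load(tmpPath)
finally:
    spark.sql("RESET spark.sql.sources.default").collect()
```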