You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/10/17 11:38:19 UTC
[spark] branch master updated: [SPARK-45562][SQL] XML: Make 'rowTag' a required option
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4d63ca6394f [SPARK-45562][SQL] XML: Make 'rowTag' a required option
4d63ca6394f is described below
commit 4d63ca6394fe8692e1f9bceb93606a86b88b5dc1
Author: Sandip Agarwala <13...@users.noreply.github.com>
AuthorDate: Tue Oct 17 20:38:02 2023 +0900
[SPARK-45562][SQL] XML: Make 'rowTag' a required option
### What changes were proposed in this pull request?
User can specify `rowTag` option that is the name of the XML element that maps to a `DataFrame Row`. A non-existent `rowTag` will not infer any schema or generate any `DataFrame` rows. Currently, not specifying `rowTag` option results in picking up its default value of `ROW`, which won't match a real XML element in most scenarios. This results in an empty `DataFrame` and confuses new users.
This PR makes `rowTag` a required option for both read and write. XML built-in functions (from_xml/schema_of_xml) ignore `rowTag` option.
### Why are the changes needed?
See above
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New unit tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43389 from sandip-db/xml-rowTag.
Authored-by: Sandip Agarwala <13...@users.noreply.github.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../apache/spark/sql/catalyst/xml/XmlOptions.scala | 4 +-
.../execution/datasources/xml/XmlFileFormat.scala | 2 +
.../execution/datasources/xml/JavaXmlSuite.java | 10 +-
.../sql/execution/datasources/xml/XmlSuite.scala | 125 +++++++++++++++------
.../xml/parsers/StaxXmlGeneratorSuite.scala | 4 +-
5 files changed, 103 insertions(+), 42 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
index d0cfff87279..0dedbec58e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
@@ -63,8 +63,8 @@ private[sql] class XmlOptions(
}
val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName)
- val rowTag = parameters.getOrElse(ROW_TAG, XmlOptions.DEFAULT_ROW_TAG)
- require(rowTag.nonEmpty, s"'$ROW_TAG' option should not be empty string.")
+ val rowTag = parameters.getOrElse(ROW_TAG, XmlOptions.DEFAULT_ROW_TAG).trim
+ require(rowTag.nonEmpty, s"'$ROW_TAG' option should not be an empty string.")
require(!rowTag.startsWith("<") && !rowTag.endsWith(">"),
s"'$ROW_TAG' should not include angle brackets")
val rootTag = parameters.getOrElse(ROOT_TAG, XmlOptions.DEFAULT_ROOT_TAG)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
index baacf7f0748..4342711b00f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
@@ -42,6 +42,8 @@ class XmlFileFormat extends TextBasedFileFormat with DataSourceRegister {
def getXmlOptions(
sparkSession: SparkSession,
parameters: Map[String, String]): XmlOptions = {
+ val rowTagOpt = parameters.get(XmlOptions.ROW_TAG)
+ require(rowTagOpt.isDefined, s"'${XmlOptions.ROW_TAG}' option is required.")
new XmlOptions(parameters,
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/datasources/xml/JavaXmlSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/datasources/xml/JavaXmlSuite.java
index b3f39180843..c773459dc4c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/execution/datasources/xml/JavaXmlSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/datasources/xml/JavaXmlSuite.java
@@ -82,7 +82,7 @@ public final class JavaXmlSuite {
public void testXmlParser() {
Map<String, String> options = new HashMap<>();
options.put("rowTag", booksFileTag);
- Dataset<Row> df = spark.read().options(options).format("xml").load(booksFile);
+ Dataset<Row> df = spark.read().options(options).xml(booksFile);
String prefix = XmlOptions.DEFAULT_ATTRIBUTE_PREFIX();
long result = df.select(prefix + "id").count();
Assertions.assertEquals(result, numBooks);
@@ -92,7 +92,7 @@ public final class JavaXmlSuite {
public void testLoad() {
Map<String, String> options = new HashMap<>();
options.put("rowTag", booksFileTag);
- Dataset<Row> df = spark.read().options(options).format("xml").load(booksFile);
+ Dataset<Row> df = spark.read().options(options).xml(booksFile);
long result = df.select("description").count();
Assertions.assertEquals(result, numBooks);
}
@@ -103,10 +103,10 @@ public final class JavaXmlSuite {
options.put("rowTag", booksFileTag);
Path booksPath = getEmptyTempDir().resolve("booksFile");
- Dataset<Row> df = spark.read().options(options).format("xml").load(booksFile);
- df.select("price", "description").write().format("xml").save(booksPath.toString());
+ Dataset<Row> df = spark.read().options(options).xml(booksFile);
+ df.select("price", "description").write().options(options).xml(booksPath.toString());
- Dataset<Row> newDf = spark.read().format("xml").load(booksPath.toString());
+ Dataset<Row> newDf = spark.read().options(options).xml(booksPath.toString());
long result = newDf.select("price").count();
Assertions.assertEquals(result, numBooks);
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index c5892abf3f8..23223b3e94e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -65,6 +65,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL test") {
val results = spark.read.format("xml")
+ .option("rowTag", "ROW")
.option("multiLine", "true")
.load(getTestResourcePath(resDir + "cars.xml"))
.select("year")
@@ -75,6 +76,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL test with xml having unbalanced datatypes") {
val results = spark.read
+ .option("rowTag", "ROW")
.option("treatEmptyValuesAsNulls", "true")
.option("multiLine", "true")
.xml(getTestResourcePath(resDir + "gps-empty-field.xml"))
@@ -84,6 +86,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL test with mixed elements (attributes, no child)") {
val results = spark.read
+ .option("rowTag", "ROW")
.xml(getTestResourcePath(resDir + "cars-mixed-attr-no-child.xml"))
.select("date")
.collect()
@@ -129,6 +132,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL test for iso-8859-1 encoded file") {
val dataFrame = spark.read
+ .option("rowTag", "ROW")
.option("charset", StandardCharsets.ISO_8859_1.name)
.xml(getTestResourcePath(resDir + "cars-iso-8859-1.xml"))
assert(dataFrame.select("year").collect().length === 3)
@@ -142,6 +146,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL test compressed file") {
val results = spark.read
+ .option("rowTag", "ROW")
.xml(getTestResourcePath(resDir + "cars.xml.gz"))
.select("year")
.collect()
@@ -151,6 +156,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL test splittable compressed file") {
val results = spark.read
+ .option("rowTag", "ROW")
.xml(getTestResourcePath(resDir + "cars.xml.bz2"))
.select("year")
.collect()
@@ -162,6 +168,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
// val exception = intercept[UnsupportedCharsetException] {
val exception = intercept[SparkException] {
spark.read
+ .option("rowTag", "ROW")
.option("charset", "1-9588-osi")
.xml(getTestResourcePath(resDir + "cars.xml"))
.select("year")
@@ -175,7 +182,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
spark.sql(s"""
|CREATE TEMPORARY VIEW carsTable1
|USING org.apache.spark.sql.execution.datasources.xml
- |OPTIONS (path "${getTestResourcePath(resDir + "cars.xml")}")
+ |OPTIONS (rowTag "ROW", path "${getTestResourcePath(resDir + "cars.xml")}")
""".stripMargin.replaceAll("\n", " "))
assert(spark.sql("SELECT year FROM carsTable1").collect().length === 3)
@@ -185,7 +192,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
spark.sql(s"""
|CREATE TEMPORARY VIEW carsTable2
|USING xml
- |OPTIONS (path "${getTestResourcePath(resDir + "cars.xml")}")
+ |OPTIONS (rowTag "ROW", path "${getTestResourcePath(resDir + "cars.xml")}")
""".stripMargin.replaceAll("\n", " "))
assert(spark.sql("SELECT year FROM carsTable2").collect().length === 3)
@@ -193,6 +200,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL test for parsing a malformed XML file") {
val results = spark.read
+ .option("rowTag", "ROW")
.option("mode", DropMalformedMode.name)
.xml(getTestResourcePath(resDir + "cars-malformed.xml"))
@@ -201,6 +209,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL test for dropping malformed rows") {
val cars = spark.read
+ .option("rowTag", "ROW")
.option("mode", DropMalformedMode.name)
.xml(getTestResourcePath(resDir + "cars-malformed.xml"))
@@ -211,6 +220,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL test for failing fast") {
val exceptionInParse = intercept[SparkException] {
spark.read
+ .option("rowTag", "ROW")
.option("mode", FailFastMode.name)
.xml(getTestResourcePath(resDir + "cars-malformed.xml"))
.collect()
@@ -245,6 +255,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL test for permissive mode for corrupt records") {
val carsDf = spark.read
+ .option("rowTag", "ROW")
.option("mode", PermissiveMode.name)
.option("columnNameOfCorruptRecord", "_malformed_records")
.xml(getTestResourcePath(resDir + "cars-malformed.xml"))
@@ -268,6 +279,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL test with empty file and known schema") {
val results = spark.read
+ .option("rowTag", "ROW")
.schema(buildSchema(field("column", StringType, false)))
.xml(getTestResourcePath(resDir + "empty.xml"))
.count()
@@ -283,6 +295,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
field("model"),
field("comment"))
val results = spark.read.schema(schema)
+ .option("rowTag", "ROW")
.xml(getTestResourcePath(resDir + "cars-unbalanced-elements.xml"))
.count()
@@ -293,8 +306,8 @@ class XmlSuite extends QueryTest with SharedSparkSession {
spark.sql(s"""
|CREATE TEMPORARY VIEW carsTable3
|(year double, make string, model string, comments string, grp string)
- |USING org.apache.spark.sql.execution.datasources.xml
- |OPTIONS (path "${getTestResourcePath(resDir + "empty.xml")}")
+ |USING xml
+ |OPTIONS (rowTag "ROW", path "${getTestResourcePath(resDir + "empty.xml")}")
""".stripMargin.replaceAll("\n", " "))
assert(spark.sql("SELECT count(*) FROM carsTable3").collect().head(0) === 0)
@@ -304,15 +317,15 @@ class XmlSuite extends QueryTest with SharedSparkSession {
val tempPath = getEmptyTempDir()
spark.sql(s"""
|CREATE TEMPORARY VIEW booksTableIO
- |USING org.apache.spark.sql.execution.datasources.xml
+ |USING xml
|OPTIONS (path "${getTestResourcePath(resDir + "books.xml")}", rowTag "book")
""".stripMargin.replaceAll("\n", " "))
spark.sql(s"""
|CREATE TEMPORARY VIEW booksTableEmpty
|(author string, description string, genre string,
|id string, price double, publish_date string, title string)
- |USING org.apache.spark.sql.execution.datasources.xml
- |OPTIONS (path "$tempPath")
+ |USING xml
+ |OPTIONS (rowTag "ROW", path "$tempPath")
""".stripMargin.replaceAll("\n", " "))
assert(spark.sql("SELECT * FROM booksTableIO").collect().length === 12)
@@ -329,16 +342,18 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL save with gzip compression codec") {
val copyFilePath = getEmptyTempDir().resolve("cars-copy.xml")
- val cars = spark.read.xml(getTestResourcePath(resDir + "cars.xml"))
+ val cars = spark.read
+ .option("rowTag", "ROW")
+ .xml(getTestResourcePath(resDir + "cars.xml"))
cars.write
.mode(SaveMode.Overwrite)
- .options(Map("compression" -> classOf[GzipCodec].getName))
+ .options(Map("rowTag" -> "ROW", "compression" -> classOf[GzipCodec].getName))
.xml(copyFilePath.toString)
// Check that the part file has a .gz extension
assert(Files.list(copyFilePath).iterator().asScala
.count(_.getFileName.toString().endsWith(".xml.gz")) === 1)
- val carsCopy = spark.read.xml(copyFilePath.toString)
+ val carsCopy = spark.read.option("rowTag", "ROW").xml(copyFilePath.toString)
assert(carsCopy.count() === cars.count())
assert(carsCopy.collect().map(_.toString).toSet === cars.collect().map(_.toString).toSet)
@@ -347,17 +362,19 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL save with gzip compression codec by shorten name") {
val copyFilePath = getEmptyTempDir().resolve("cars-copy.xml")
- val cars = spark.read.xml(getTestResourcePath(resDir + "cars.xml"))
+ val cars = spark.read
+ .option("rowTag", "ROW")
+ .xml(getTestResourcePath(resDir + "cars.xml"))
cars.write
.mode(SaveMode.Overwrite)
- .options(Map("compression" -> "gZiP"))
+ .options(Map("rowTag" -> "ROW", "compression" -> "gZiP"))
.xml(copyFilePath.toString)
// Check that the part file has a .gz extension
assert(Files.list(copyFilePath).iterator().asScala
.count(_.getFileName.toString().endsWith(".xml.gz")) === 1)
- val carsCopy = spark.read.xml(copyFilePath.toString)
+ val carsCopy = spark.read.option("rowTag", "ROW").xml(copyFilePath.toString)
assert(carsCopy.count() === cars.count())
assert(carsCopy.collect().map(_.toString).toSet === cars.collect().map(_.toString).toSet)
@@ -413,7 +430,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("DSL save with item") {
val tempPath = getEmptyTempDir().resolve("items-temp.xml")
val items = spark.createDataFrame(Seq(Tuple1(Array(Array(3, 4))))).toDF("thing").repartition(1)
- items.write.option("arrayElementName", "foo").xml(tempPath.toString)
+ items.write
+ .option("rowTag", "ROW")
+ .option("arrayElementName", "foo").xml(tempPath.toString)
val xmlFile =
Files.list(tempPath).iterator.asScala
@@ -474,7 +493,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
val data = spark.sparkContext.parallelize(
List(List(List("aa", "bb"), List("aa", "bb"))).map(Row(_)))
val df = spark.createDataFrame(data, schema)
- df.write.xml(copyFilePath.toString)
+ df.write.option("rowTag", "ROW").xml(copyFilePath.toString)
// When [[ArrayType]] has [[ArrayType]] as elements, it is confusing what is the element
// name for XML file. Now, it is "item" by default. So, "item" field is additionally added
@@ -482,7 +501,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
val schemaCopy = buildSchema(
structArray("a",
field(XmlOptions.DEFAULT_ARRAY_ELEMENT_NAME, ArrayType(StringType))))
- val dfCopy = spark.read.xml(copyFilePath.toString)
+ val dfCopy = spark.read.option("rowTag", "ROW").xml(copyFilePath.toString)
assert(dfCopy.count() === df.count())
assert(dfCopy.schema === schemaCopy)
@@ -518,9 +537,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
val data = spark.sparkContext.parallelize(Seq(row))
val df = spark.createDataFrame(data, schema)
- df.write.xml(copyFilePath.toString)
+ df.write.option("rowTag", "ROW").xml(copyFilePath.toString)
- val dfCopy = spark.read.schema(schema)
+ val dfCopy = spark.read.option("rowTag", "ROW").schema(schema)
.xml(copyFilePath.toString)
assert(dfCopy.collect() === df.collect())
@@ -685,7 +704,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
field("comment"),
field("color"),
field("year", IntegerType))
- val results = spark.read.schema(schema)
+ val results = spark.read.option("rowTag", "ROW").schema(schema)
.xml(getTestResourcePath(resDir + "cars-unbalanced-elements.xml"))
.count()
@@ -693,7 +712,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
}
test("DSL test inferred schema passed through") {
- val dataFrame = spark.read.xml(getTestResourcePath(resDir + "cars.xml"))
+ val dataFrame = spark.read.option("rowTag", "ROW").xml(getTestResourcePath(resDir + "cars.xml"))
val results = dataFrame
.select("comment", "year")
@@ -706,7 +725,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
val schema = buildSchema(
field("name", StringType, false),
field("age"))
- val results = spark.read.schema(schema)
+ val results = spark.read.option("rowTag", "ROW").schema(schema)
.xml(getTestResourcePath(resDir + "null-numbers.xml"))
.select("name", "age")
.collect()
@@ -721,6 +740,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
field("name", StringType, false),
field("age", IntegerType))
val results = spark.read.schema(schema)
+ .option("rowTag", "ROW")
.option("treatEmptyValuesAsNulls", true)
.xml(getTestResourcePath(resDir + "null-numbers.xml"))
.select("name", "age")
@@ -808,6 +828,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
field("a", IntegerType)))
val result = spark.read.schema(schema)
+ .option("rowTag", "ROW")
.xml(getTestResourcePath(resDir + "simple-nested-objects.xml"))
.select("c.a", "c.b")
.collect()
@@ -858,7 +879,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
}
test("Skip and project currently XML files without indentation") {
- val df = spark.read.xml(getTestResourcePath(resDir + "cars-no-indentation.xml"))
+ val df = spark.read
+ .option("rowTag", "ROW")
+ .xml(getTestResourcePath(resDir + "cars-no-indentation.xml"))
val results = df.select("model").collect()
val years = results.map(_(0)).toSet
assert(years === Set("S", "E350", "Volt"))
@@ -880,10 +903,11 @@ class XmlSuite extends QueryTest with SharedSparkSession {
val messageOne = intercept[IllegalArgumentException] {
spark.read.option("rowTag", "").xml(getTestResourcePath(resDir + "cars.xml"))
}.getMessage
- assert(messageOne === "requirement failed: 'rowTag' option should not be empty string.")
+ assert(messageOne === "requirement failed: 'rowTag' option should not be an empty string.")
val messageThree = intercept[IllegalArgumentException] {
- spark.read.option("valueTag", "").xml(getTestResourcePath(resDir + "cars.xml"))
+ spark.read.option("rowTag", "ROW")
+ .option("valueTag", "").xml(getTestResourcePath(resDir + "cars.xml"))
}.getMessage
assert(messageThree === "requirement failed: 'valueTag' option should not be empty string.")
}
@@ -895,18 +919,21 @@ class XmlSuite extends QueryTest with SharedSparkSession {
assert(messageOne === "requirement failed: 'rowTag' should not include angle brackets")
val messageTwo = intercept[IllegalArgumentException] {
- spark.read.option("rowTag", "<ROW").xml(getTestResourcePath(resDir + "cars.xml"))
+ spark.read.option("rowTag", "ROW")
+ .option("rowTag", "<ROW").xml(getTestResourcePath(resDir + "cars.xml"))
}.getMessage
assert(
messageTwo === "requirement failed: 'rowTag' should not include angle brackets")
val messageThree = intercept[IllegalArgumentException] {
- spark.read.option("rootTag", "ROWSET>").xml(getTestResourcePath(resDir + "cars.xml"))
+ spark.read.option("rowTag", "ROW")
+ .option("rootTag", "ROWSET>").xml(getTestResourcePath(resDir + "cars.xml"))
}.getMessage
assert(messageThree === "requirement failed: 'rootTag' should not include angle brackets")
val messageFour = intercept[IllegalArgumentException] {
- spark.read.option("rootTag", "<ROWSET").xml(getTestResourcePath(resDir + "cars.xml"))
+ spark.read.option("rowTag", "ROW")
+ .option("rootTag", "<ROWSET").xml(getTestResourcePath(resDir + "cars.xml"))
}.getMessage
assert(messageFour === "requirement failed: 'rootTag' should not include angle brackets")
}
@@ -914,6 +941,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("valueTag and attributePrefix should not be the same.") {
val messageOne = intercept[IllegalArgumentException] {
spark.read
+ .option("rowTag", "ROW")
.option("valueTag", "#abc")
.option("attributePrefix", "#abc")
.xml(getTestResourcePath(resDir + "cars.xml"))
@@ -924,6 +952,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("nullValue and treatEmptyValuesAsNulls test") {
val resultsOne = spark.read
+ .option("rowTag", "ROW")
.option("treatEmptyValuesAsNulls", "true")
.xml(getTestResourcePath(resDir + "gps-empty-field.xml"))
assert(resultsOne.selectExpr("extensions.TrackPointExtension").head().getStruct(0) !== null)
@@ -934,6 +963,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
assert(resultsOne.collect().length === 2)
val resultsTwo = spark.read
+ .option("rowTag", "ROW")
.option("nullValue", "2013-01-24T06:18:43Z")
.xml(getTestResourcePath(resDir + "gps-empty-field.xml"))
assert(resultsTwo.selectExpr("time").head().getStruct(0) === null)
@@ -1015,7 +1045,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
field("non-empty-tag", IntegerType),
field("self-closing-tag", IntegerType))
- val result = spark.read.schema(schema)
+ val result = spark.read.option("rowTag", "ROW").schema(schema)
.xml(getTestResourcePath(resDir + "self-closing-tag.xml"))
.collect()
@@ -1054,6 +1084,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
field("integer_map", MapType(StringType, IntegerType)),
field("_malformed_records", StringType))
val results = spark.read
+ .option("rowTag", "ROW")
.option("mode", "PERMISSIVE")
.option("columnNameOfCorruptRecord", "_malformed_records")
.schema(schema)
@@ -1178,7 +1209,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
"<ROW><year>2015</year><make>Chevy</make><model>Volt</model><comment>No</comment></ROW>")
val xmlRDD = spark.sparkContext.parallelize(data)
val ds = spark.createDataset(xmlRDD)(Encoders.STRING)
- assert(spark.read.xml(ds).collect().length === 3)
+ assert(spark.read.option("rowTag", "ROW").xml(ds).collect().length === 3)
}
import testImplicits._
@@ -1308,10 +1339,11 @@ class XmlSuite extends QueryTest with SharedSparkSession {
test("rootTag with simple attributes") {
val xmlPath = getEmptyTempDir().resolve("simple_attributes")
val df = spark.createDataFrame(Seq((42, "foo"))).toDF("number", "value").repartition(1)
- df.write.
- option("rootTag", "root foo='bar' bing=\"baz\"").
- option("declaration", "").
- xml(xmlPath.toString)
+ df.write
+ .option("rowTag", "ROW")
+ .option("rootTag", "root foo='bar' bing=\"baz\"")
+ .option("declaration", "")
+ .xml(xmlPath.toString)
val xmlFile =
Files.list(xmlPath).iterator.asScala.filter(_.getFileName.toString.startsWith("part-")).next()
@@ -1651,10 +1683,12 @@ class XmlSuite extends QueryTest with SharedSparkSession {
val results = Seq(
// user specified schema
spark.read
+ .option("rowTag", "ROW")
.schema(schema)
.xml(getTestResourcePath(resDir + "root-level-value.xml")).collect(),
// schema inference
spark.read
+ .option("rowTag", "ROW")
.xml(getTestResourcePath(resDir + "root-level-value.xml")).collect())
results.foreach { result =>
assert(result.length === 3)
@@ -1677,10 +1711,12 @@ class XmlSuite extends QueryTest with SharedSparkSession {
val dfs = Seq(
// user specified schema
spark.read
+ .option("rowTag", "ROW")
.schema(schema)
.xml(getTestResourcePath(resDir + "root-level-value-none.xml")),
// schema inference
spark.read
+ .option("rowTag", "ROW")
.xml(getTestResourcePath(resDir + "root-level-value-none.xml"))
)
dfs.foreach { df =>
@@ -1720,4 +1756,27 @@ class XmlSuite extends QueryTest with SharedSparkSession {
assert(result.select("decoded._VALUE").head().getLong(0) === 123456L)
assert(result.select("decoded._attr").head().getString(0) === "attr1")
}
+
+ test("Test XML Options Error Messages") {
+ def checkXmlOptionErrorMessage(
+ parameters: Map[String, String] = Map.empty,
+ msg: String): Unit = {
+ val e = intercept[IllegalArgumentException] {
+ spark.read
+ .options(parameters)
+ .xml(getTestResourcePath(resDir + "ages.xml"))
+ .collect()
+ }
+ assert(e.getMessage.contains(msg))
+ }
+
+ checkXmlOptionErrorMessage(Map.empty, "'rowTag' option is required.")
+ checkXmlOptionErrorMessage(Map("rowTag" -> ""),
+ "'rowTag' option should not be an empty string.")
+ checkXmlOptionErrorMessage(Map("rowTag" -> " "),
+ "'rowTag' option should not be an empty string.")
+ checkXmlOptionErrorMessage(Map("rowTag" -> "person",
+ "declaration" -> s"<${XmlOptions.DEFAULT_DECLARATION}>"),
+ "'declaration' should not include angle brackets")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala
index 176cfd98563..1798d32d8a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala
@@ -69,9 +69,9 @@ final class StaxXmlGeneratorSuite extends SharedSparkSession {
val df = dataset.toDF().orderBy("booleanDatum")
val targetFile =
Files.createTempDirectory("StaxXmlGeneratorSuite").resolve("roundtrip.xml").toString
- df.write.format("xml").save(targetFile)
+ df.write.option("rowTag", "ROW").xml(targetFile)
val newDf =
- spark.read.schema(df.schema).format("xml").load(targetFile).orderBy("booleanDatum")
+ spark.read.option("rowTag", "ROW").schema(df.schema).xml(targetFile).orderBy("booleanDatum")
assert(df.collect().toSeq === newDf.collect().toSeq)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org