Posted to commits@spark.apache.org by gu...@apache.org on 2023/10/17 11:38:19 UTC

[spark] branch master updated: [SPARK-45562][SQL] XML: Make 'rowTag' a required option

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4d63ca6394f [SPARK-45562][SQL] XML: Make 'rowTag' a required option
4d63ca6394f is described below

commit 4d63ca6394fe8692e1f9bceb93606a86b88b5dc1
Author: Sandip Agarwala <13...@users.noreply.github.com>
AuthorDate: Tue Oct 17 20:38:02 2023 +0900

    [SPARK-45562][SQL] XML: Make 'rowTag' a required option
    
    ### What changes were proposed in this pull request?
    Users can specify the `rowTag` option, which names the XML element that maps to a `DataFrame` `Row`. A non-existent `rowTag` will not infer any schema or generate any `DataFrame` rows. Currently, omitting the `rowTag` option picks up its default value of `ROW`, which won't match a real XML element in most scenarios. This results in an empty DataFrame and confuses new users.
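    
    A minimal sketch of the confusion (hypothetical file and element names, using the same `.xml()` reader shortcut as the tests below):
    
    ```scala
    // books.xml contains <book> rows, not <ROW> rows.
    // Before this change, omitting 'rowTag' silently fell back to the default "ROW":
    val empty = spark.read.format("xml").load("books.xml")
    empty.count() // 0 -- no <ROW> element matched, so the DataFrame is empty
    
    // Naming the actual row element returns the expected rows:
    val books = spark.read.option("rowTag", "book").xml("books.xml")
    ```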
    
    This PR makes `rowTag` a required option for both read and write. The XML built-in functions (`from_xml`/`schema_of_xml`) ignore the `rowTag` option.
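    
    A hedged sketch of the resulting behavior (the error text comes from the `require` added below; `df`, the column name `payload`, and the exact `from_xml`/`schema_of_xml` overloads are assumptions):
    
    ```scala
    import org.apache.spark.sql.functions.{col, from_xml, schema_of_xml}
    
    // Reads and writes now fail fast when 'rowTag' is missing:
    //   spark.read.xml(path)
    //   => IllegalArgumentException: requirement failed: 'rowTag' option is required.
    
    // The built-in functions parse one XML string per row and therefore
    // ignore 'rowTag':
    val parsed = df.select(from_xml(col("payload"), schema_of_xml("<p><a>1</a></p>")))
    ```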
    
    ### Why are the changes needed?
    See above
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    New unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43389 from sandip-db/xml-rowTag.
    
    Authored-by: Sandip Agarwala <13...@users.noreply.github.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../apache/spark/sql/catalyst/xml/XmlOptions.scala |   4 +-
 .../execution/datasources/xml/XmlFileFormat.scala  |   2 +
 .../execution/datasources/xml/JavaXmlSuite.java    |  10 +-
 .../sql/execution/datasources/xml/XmlSuite.scala   | 125 +++++++++++++++------
 .../xml/parsers/StaxXmlGeneratorSuite.scala        |   4 +-
 5 files changed, 103 insertions(+), 42 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
index d0cfff87279..0dedbec58e1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlOptions.scala
@@ -63,8 +63,8 @@ private[sql] class XmlOptions(
   }
 
   val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName)
-  val rowTag = parameters.getOrElse(ROW_TAG, XmlOptions.DEFAULT_ROW_TAG)
-  require(rowTag.nonEmpty, s"'$ROW_TAG' option should not be empty string.")
+  val rowTag = parameters.getOrElse(ROW_TAG, XmlOptions.DEFAULT_ROW_TAG).trim
+  require(rowTag.nonEmpty, s"'$ROW_TAG' option should not be an empty string.")
   require(!rowTag.startsWith("<") && !rowTag.endsWith(">"),
           s"'$ROW_TAG' should not include angle brackets")
   val rootTag = parameters.getOrElse(ROOT_TAG, XmlOptions.DEFAULT_ROOT_TAG)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
index baacf7f0748..4342711b00f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/xml/XmlFileFormat.scala
@@ -42,6 +42,8 @@ class XmlFileFormat extends TextBasedFileFormat with DataSourceRegister {
   def getXmlOptions(
       sparkSession: SparkSession,
       parameters: Map[String, String]): XmlOptions = {
+    val rowTagOpt = parameters.get(XmlOptions.ROW_TAG)
+    require(rowTagOpt.isDefined, s"'${XmlOptions.ROW_TAG}' option is required.")
     new XmlOptions(parameters,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/datasources/xml/JavaXmlSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/datasources/xml/JavaXmlSuite.java
index b3f39180843..c773459dc4c 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/execution/datasources/xml/JavaXmlSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/datasources/xml/JavaXmlSuite.java
@@ -82,7 +82,7 @@ public final class JavaXmlSuite {
     public void testXmlParser() {
         Map<String, String> options = new HashMap<>();
         options.put("rowTag", booksFileTag);
-        Dataset<Row> df = spark.read().options(options).format("xml").load(booksFile);
+        Dataset<Row> df = spark.read().options(options).xml(booksFile);
         String prefix = XmlOptions.DEFAULT_ATTRIBUTE_PREFIX();
         long result = df.select(prefix + "id").count();
         Assertions.assertEquals(result, numBooks);
@@ -92,7 +92,7 @@ public final class JavaXmlSuite {
     public void testLoad() {
         Map<String, String> options = new HashMap<>();
         options.put("rowTag", booksFileTag);
-        Dataset<Row> df = spark.read().options(options).format("xml").load(booksFile);
+        Dataset<Row> df = spark.read().options(options).xml(booksFile);
         long result = df.select("description").count();
         Assertions.assertEquals(result, numBooks);
     }
@@ -103,10 +103,10 @@ public final class JavaXmlSuite {
         options.put("rowTag", booksFileTag);
         Path booksPath = getEmptyTempDir().resolve("booksFile");
 
-        Dataset<Row> df = spark.read().options(options).format("xml").load(booksFile);
-        df.select("price", "description").write().format("xml").save(booksPath.toString());
+        Dataset<Row> df = spark.read().options(options).xml(booksFile);
+        df.select("price", "description").write().options(options).xml(booksPath.toString());
 
-        Dataset<Row> newDf = spark.read().format("xml").load(booksPath.toString());
+        Dataset<Row> newDf = spark.read().options(options).xml(booksPath.toString());
         long result = newDf.select("price").count();
         Assertions.assertEquals(result, numBooks);
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
index c5892abf3f8..23223b3e94e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala
@@ -65,6 +65,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test") {
     val results = spark.read.format("xml")
+      .option("rowTag", "ROW")
       .option("multiLine", "true")
       .load(getTestResourcePath(resDir + "cars.xml"))
       .select("year")
@@ -75,6 +76,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test with xml having unbalanced datatypes") {
     val results = spark.read
+      .option("rowTag", "ROW")
       .option("treatEmptyValuesAsNulls", "true")
       .option("multiLine", "true")
       .xml(getTestResourcePath(resDir + "gps-empty-field.xml"))
@@ -84,6 +86,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test with mixed elements (attributes, no child)") {
     val results = spark.read
+      .option("rowTag", "ROW")
       .xml(getTestResourcePath(resDir + "cars-mixed-attr-no-child.xml"))
       .select("date")
       .collect()
@@ -129,6 +132,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test for iso-8859-1 encoded file") {
     val dataFrame = spark.read
+      .option("rowTag", "ROW")
       .option("charset", StandardCharsets.ISO_8859_1.name)
       .xml(getTestResourcePath(resDir + "cars-iso-8859-1.xml"))
     assert(dataFrame.select("year").collect().length === 3)
@@ -142,6 +146,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test compressed file") {
     val results = spark.read
+      .option("rowTag", "ROW")
       .xml(getTestResourcePath(resDir + "cars.xml.gz"))
       .select("year")
       .collect()
@@ -151,6 +156,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test splittable compressed file") {
     val results = spark.read
+      .option("rowTag", "ROW")
       .xml(getTestResourcePath(resDir + "cars.xml.bz2"))
       .select("year")
       .collect()
@@ -162,6 +168,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     // val exception = intercept[UnsupportedCharsetException] {
     val exception = intercept[SparkException] {
       spark.read
+        .option("rowTag", "ROW")
         .option("charset", "1-9588-osi")
         .xml(getTestResourcePath(resDir + "cars.xml"))
         .select("year")
@@ -175,7 +182,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     spark.sql(s"""
          |CREATE TEMPORARY VIEW carsTable1
          |USING org.apache.spark.sql.execution.datasources.xml
-         |OPTIONS (path "${getTestResourcePath(resDir + "cars.xml")}")
+         |OPTIONS (rowTag "ROW", path "${getTestResourcePath(resDir + "cars.xml")}")
       """.stripMargin.replaceAll("\n", " "))
 
     assert(spark.sql("SELECT year FROM carsTable1").collect().length === 3)
@@ -185,7 +192,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     spark.sql(s"""
          |CREATE TEMPORARY VIEW carsTable2
          |USING xml
-         |OPTIONS (path "${getTestResourcePath(resDir + "cars.xml")}")
+         |OPTIONS (rowTag "ROW", path "${getTestResourcePath(resDir + "cars.xml")}")
       """.stripMargin.replaceAll("\n", " "))
 
     assert(spark.sql("SELECT year FROM carsTable2").collect().length === 3)
@@ -193,6 +200,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test for parsing a malformed XML file") {
     val results = spark.read
+      .option("rowTag", "ROW")
       .option("mode", DropMalformedMode.name)
       .xml(getTestResourcePath(resDir + "cars-malformed.xml"))
 
@@ -201,6 +209,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test for dropping malformed rows") {
     val cars = spark.read
+      .option("rowTag", "ROW")
       .option("mode", DropMalformedMode.name)
       .xml(getTestResourcePath(resDir + "cars-malformed.xml"))
 
@@ -211,6 +220,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   test("DSL test for failing fast") {
     val exceptionInParse = intercept[SparkException] {
       spark.read
+        .option("rowTag", "ROW")
         .option("mode", FailFastMode.name)
         .xml(getTestResourcePath(resDir + "cars-malformed.xml"))
         .collect()
@@ -245,6 +255,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test for permissive mode for corrupt records") {
     val carsDf = spark.read
+      .option("rowTag", "ROW")
       .option("mode", PermissiveMode.name)
       .option("columnNameOfCorruptRecord", "_malformed_records")
       .xml(getTestResourcePath(resDir + "cars-malformed.xml"))
@@ -268,6 +279,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("DSL test with empty file and known schema") {
     val results = spark.read
+      .option("rowTag", "ROW")
       .schema(buildSchema(field("column", StringType, false)))
       .xml(getTestResourcePath(resDir + "empty.xml"))
       .count()
@@ -283,6 +295,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       field("model"),
       field("comment"))
     val results = spark.read.schema(schema)
+      .option("rowTag", "ROW")
       .xml(getTestResourcePath(resDir + "cars-unbalanced-elements.xml"))
       .count()
 
@@ -293,8 +306,8 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     spark.sql(s"""
            |CREATE TEMPORARY VIEW carsTable3
            |(year double, make string, model string, comments string, grp string)
-           |USING org.apache.spark.sql.execution.datasources.xml
-           |OPTIONS (path "${getTestResourcePath(resDir + "empty.xml")}")
+           |USING xml
+           |OPTIONS (rowTag "ROW", path "${getTestResourcePath(resDir + "empty.xml")}")
       """.stripMargin.replaceAll("\n", " "))
 
     assert(spark.sql("SELECT count(*) FROM carsTable3").collect().head(0) === 0)
@@ -304,15 +317,15 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val tempPath = getEmptyTempDir()
     spark.sql(s"""
          |CREATE TEMPORARY VIEW booksTableIO
-         |USING org.apache.spark.sql.execution.datasources.xml
+         |USING xml
          |OPTIONS (path "${getTestResourcePath(resDir + "books.xml")}", rowTag "book")
       """.stripMargin.replaceAll("\n", " "))
     spark.sql(s"""
          |CREATE TEMPORARY VIEW booksTableEmpty
          |(author string, description string, genre string,
          |id string, price double, publish_date string, title string)
-         |USING org.apache.spark.sql.execution.datasources.xml
-         |OPTIONS (path "$tempPath")
+         |USING xml
+         |OPTIONS (rowTag "ROW", path "$tempPath")
       """.stripMargin.replaceAll("\n", " "))
 
     assert(spark.sql("SELECT * FROM booksTableIO").collect().length === 12)
@@ -329,16 +342,18 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   test("DSL save with gzip compression codec") {
     val copyFilePath = getEmptyTempDir().resolve("cars-copy.xml")
 
-    val cars = spark.read.xml(getTestResourcePath(resDir + "cars.xml"))
+    val cars = spark.read
+      .option("rowTag", "ROW")
+      .xml(getTestResourcePath(resDir + "cars.xml"))
     cars.write
       .mode(SaveMode.Overwrite)
-      .options(Map("compression" -> classOf[GzipCodec].getName))
+      .options(Map("rowTag" -> "ROW", "compression" -> classOf[GzipCodec].getName))
       .xml(copyFilePath.toString)
     // Check that the part file has a .gz extension
     assert(Files.list(copyFilePath).iterator().asScala
       .count(_.getFileName.toString().endsWith(".xml.gz")) === 1)
 
-    val carsCopy = spark.read.xml(copyFilePath.toString)
+    val carsCopy = spark.read.option("rowTag", "ROW").xml(copyFilePath.toString)
 
     assert(carsCopy.count() === cars.count())
     assert(carsCopy.collect().map(_.toString).toSet === cars.collect().map(_.toString).toSet)
@@ -347,17 +362,19 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   test("DSL save with gzip compression codec by shorten name") {
     val copyFilePath = getEmptyTempDir().resolve("cars-copy.xml")
 
-    val cars = spark.read.xml(getTestResourcePath(resDir + "cars.xml"))
+    val cars = spark.read
+      .option("rowTag", "ROW")
+      .xml(getTestResourcePath(resDir + "cars.xml"))
     cars.write
       .mode(SaveMode.Overwrite)
-      .options(Map("compression" -> "gZiP"))
+      .options(Map("rowTag" -> "ROW", "compression" -> "gZiP"))
       .xml(copyFilePath.toString)
 
     // Check that the part file has a .gz extension
     assert(Files.list(copyFilePath).iterator().asScala
       .count(_.getFileName.toString().endsWith(".xml.gz")) === 1)
 
-    val carsCopy = spark.read.xml(copyFilePath.toString)
+    val carsCopy = spark.read.option("rowTag", "ROW").xml(copyFilePath.toString)
 
     assert(carsCopy.count() === cars.count())
     assert(carsCopy.collect().map(_.toString).toSet === cars.collect().map(_.toString).toSet)
@@ -413,7 +430,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   test("DSL save with item") {
     val tempPath = getEmptyTempDir().resolve("items-temp.xml")
     val items = spark.createDataFrame(Seq(Tuple1(Array(Array(3, 4))))).toDF("thing").repartition(1)
-    items.write.option("arrayElementName", "foo").xml(tempPath.toString)
+    items.write
+      .option("rowTag", "ROW")
+      .option("arrayElementName", "foo").xml(tempPath.toString)
 
     val xmlFile =
       Files.list(tempPath).iterator.asScala
@@ -474,7 +493,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val data = spark.sparkContext.parallelize(
       List(List(List("aa", "bb"), List("aa", "bb"))).map(Row(_)))
     val df = spark.createDataFrame(data, schema)
-    df.write.xml(copyFilePath.toString)
+    df.write.option("rowTag", "ROW").xml(copyFilePath.toString)
 
     // When [[ArrayType]] has [[ArrayType]] as elements, it is confusing what is the element
     // name for XML file. Now, it is "item" by default. So, "item" field is additionally added
@@ -482,7 +501,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val schemaCopy = buildSchema(
       structArray("a",
         field(XmlOptions.DEFAULT_ARRAY_ELEMENT_NAME, ArrayType(StringType))))
-    val dfCopy = spark.read.xml(copyFilePath.toString)
+    val dfCopy = spark.read.option("rowTag", "ROW").xml(copyFilePath.toString)
 
     assert(dfCopy.count() === df.count())
     assert(dfCopy.schema === schemaCopy)
@@ -518,9 +537,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       val data = spark.sparkContext.parallelize(Seq(row))
 
       val df = spark.createDataFrame(data, schema)
-      df.write.xml(copyFilePath.toString)
+      df.write.option("rowTag", "ROW").xml(copyFilePath.toString)
 
-      val dfCopy = spark.read.schema(schema)
+      val dfCopy = spark.read.option("rowTag", "ROW").schema(schema)
         .xml(copyFilePath.toString)
 
       assert(dfCopy.collect() === df.collect())
@@ -685,7 +704,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       field("comment"),
       field("color"),
       field("year", IntegerType))
-    val results = spark.read.schema(schema)
+    val results = spark.read.option("rowTag", "ROW").schema(schema)
       .xml(getTestResourcePath(resDir + "cars-unbalanced-elements.xml"))
       .count()
 
@@ -693,7 +712,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   }
 
   test("DSL test inferred schema passed through") {
-    val dataFrame = spark.read.xml(getTestResourcePath(resDir + "cars.xml"))
+    val dataFrame = spark.read.option("rowTag", "ROW").xml(getTestResourcePath(resDir + "cars.xml"))
 
     val results = dataFrame
       .select("comment", "year")
@@ -706,7 +725,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val schema = buildSchema(
       field("name", StringType, false),
       field("age"))
-    val results = spark.read.schema(schema)
+    val results = spark.read.option("rowTag", "ROW").schema(schema)
       .xml(getTestResourcePath(resDir + "null-numbers.xml"))
       .select("name", "age")
       .collect()
@@ -721,6 +740,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       field("name", StringType, false),
       field("age", IntegerType))
     val results = spark.read.schema(schema)
+      .option("rowTag", "ROW")
       .option("treatEmptyValuesAsNulls", true)
       .xml(getTestResourcePath(resDir + "null-numbers.xml"))
       .select("name", "age")
@@ -808,6 +828,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
         field("a", IntegerType)))
 
     val result = spark.read.schema(schema)
+      .option("rowTag", "ROW")
       .xml(getTestResourcePath(resDir + "simple-nested-objects.xml"))
       .select("c.a", "c.b")
       .collect()
@@ -858,7 +879,9 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   }
 
   test("Skip and project currently XML files without indentation") {
-    val df = spark.read.xml(getTestResourcePath(resDir + "cars-no-indentation.xml"))
+    val df = spark.read
+      .option("rowTag", "ROW")
+      .xml(getTestResourcePath(resDir + "cars-no-indentation.xml"))
     val results = df.select("model").collect()
     val years = results.map(_(0)).toSet
     assert(years === Set("S", "E350", "Volt"))
@@ -880,10 +903,11 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val messageOne = intercept[IllegalArgumentException] {
       spark.read.option("rowTag", "").xml(getTestResourcePath(resDir + "cars.xml"))
     }.getMessage
-    assert(messageOne === "requirement failed: 'rowTag' option should not be empty string.")
+    assert(messageOne === "requirement failed: 'rowTag' option should not be an empty string.")
 
     val messageThree = intercept[IllegalArgumentException] {
-      spark.read.option("valueTag", "").xml(getTestResourcePath(resDir + "cars.xml"))
+      spark.read.option("rowTag", "ROW")
+        .option("valueTag", "").xml(getTestResourcePath(resDir + "cars.xml"))
     }.getMessage
     assert(messageThree === "requirement failed: 'valueTag' option should not be empty string.")
   }
@@ -895,18 +919,21 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     assert(messageOne === "requirement failed: 'rowTag' should not include angle brackets")
 
     val messageTwo = intercept[IllegalArgumentException] {
-            spark.read.option("rowTag", "<ROW").xml(getTestResourcePath(resDir + "cars.xml"))
+            spark.read.option("rowTag", "ROW")
+              .option("rowTag", "<ROW").xml(getTestResourcePath(resDir + "cars.xml"))
     }.getMessage
     assert(
       messageTwo === "requirement failed: 'rowTag' should not include angle brackets")
 
     val messageThree = intercept[IllegalArgumentException] {
-      spark.read.option("rootTag", "ROWSET>").xml(getTestResourcePath(resDir + "cars.xml"))
+      spark.read.option("rowTag", "ROW")
+        .option("rootTag", "ROWSET>").xml(getTestResourcePath(resDir + "cars.xml"))
     }.getMessage
     assert(messageThree === "requirement failed: 'rootTag' should not include angle brackets")
 
     val messageFour = intercept[IllegalArgumentException] {
-      spark.read.option("rootTag", "<ROWSET").xml(getTestResourcePath(resDir + "cars.xml"))
+      spark.read.option("rowTag", "ROW")
+        .option("rootTag", "<ROWSET").xml(getTestResourcePath(resDir + "cars.xml"))
     }.getMessage
     assert(messageFour === "requirement failed: 'rootTag' should not include angle brackets")
   }
@@ -914,6 +941,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   test("valueTag and attributePrefix should not be the same.") {
     val messageOne = intercept[IllegalArgumentException] {
       spark.read
+        .option("rowTag", "ROW")
         .option("valueTag", "#abc")
         .option("attributePrefix", "#abc")
         .xml(getTestResourcePath(resDir + "cars.xml"))
@@ -924,6 +952,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
 
   test("nullValue and treatEmptyValuesAsNulls test") {
     val resultsOne = spark.read
+      .option("rowTag", "ROW")
       .option("treatEmptyValuesAsNulls", "true")
       .xml(getTestResourcePath(resDir + "gps-empty-field.xml"))
     assert(resultsOne.selectExpr("extensions.TrackPointExtension").head().getStruct(0) !== null)
@@ -934,6 +963,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     assert(resultsOne.collect().length === 2)
 
     val resultsTwo = spark.read
+      .option("rowTag", "ROW")
       .option("nullValue", "2013-01-24T06:18:43Z")
       .xml(getTestResourcePath(resDir + "gps-empty-field.xml"))
     assert(resultsTwo.selectExpr("time").head().getStruct(0) === null)
@@ -1015,7 +1045,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       field("non-empty-tag", IntegerType),
       field("self-closing-tag", IntegerType))
 
-    val result = spark.read.schema(schema)
+    val result = spark.read.option("rowTag", "ROW").schema(schema)
       .xml(getTestResourcePath(resDir + "self-closing-tag.xml"))
       .collect()
 
@@ -1054,6 +1084,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       field("integer_map", MapType(StringType, IntegerType)),
       field("_malformed_records", StringType))
     val results = spark.read
+      .option("rowTag", "ROW")
       .option("mode", "PERMISSIVE")
       .option("columnNameOfCorruptRecord", "_malformed_records")
       .schema(schema)
@@ -1178,7 +1209,7 @@ class XmlSuite extends QueryTest with SharedSparkSession {
       "<ROW><year>2015</year><make>Chevy</make><model>Volt</model><comment>No</comment></ROW>")
     val xmlRDD = spark.sparkContext.parallelize(data)
     val ds = spark.createDataset(xmlRDD)(Encoders.STRING)
-    assert(spark.read.xml(ds).collect().length === 3)
+    assert(spark.read.option("rowTag", "ROW").xml(ds).collect().length === 3)
   }
 
   import testImplicits._
@@ -1308,10 +1339,11 @@ class XmlSuite extends QueryTest with SharedSparkSession {
   test("rootTag with simple attributes") {
     val xmlPath = getEmptyTempDir().resolve("simple_attributes")
     val df = spark.createDataFrame(Seq((42, "foo"))).toDF("number", "value").repartition(1)
-    df.write.
-      option("rootTag", "root foo='bar' bing=\"baz\"").
-      option("declaration", "").
-      xml(xmlPath.toString)
+    df.write
+      .option("rowTag", "ROW")
+      .option("rootTag", "root foo='bar' bing=\"baz\"")
+      .option("declaration", "")
+      .xml(xmlPath.toString)
 
     val xmlFile =
       Files.list(xmlPath).iterator.asScala.filter(_.getFileName.toString.startsWith("part-")).next()
@@ -1651,10 +1683,12 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val results = Seq(
       // user specified schema
       spark.read
+        .option("rowTag", "ROW")
         .schema(schema)
         .xml(getTestResourcePath(resDir + "root-level-value.xml")).collect(),
       // schema inference
       spark.read
+        .option("rowTag", "ROW")
         .xml(getTestResourcePath(resDir + "root-level-value.xml")).collect())
     results.foreach { result =>
       assert(result.length === 3)
@@ -1677,10 +1711,12 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     val dfs = Seq(
       // user specified schema
       spark.read
+        .option("rowTag", "ROW")
         .schema(schema)
         .xml(getTestResourcePath(resDir + "root-level-value-none.xml")),
       // schema inference
       spark.read
+        .option("rowTag", "ROW")
         .xml(getTestResourcePath(resDir + "root-level-value-none.xml"))
     )
     dfs.foreach { df =>
@@ -1720,4 +1756,27 @@ class XmlSuite extends QueryTest with SharedSparkSession {
     assert(result.select("decoded._VALUE").head().getLong(0) === 123456L)
     assert(result.select("decoded._attr").head().getString(0) === "attr1")
   }
+
+  test("Test XML Options Error Messages") {
+    def checkXmlOptionErrorMessage(
+      parameters: Map[String, String] = Map.empty,
+      msg: String): Unit = {
+      val e = intercept[IllegalArgumentException] {
+        spark.read
+          .options(parameters)
+          .xml(getTestResourcePath(resDir + "ages.xml"))
+          .collect()
+      }
+      assert(e.getMessage.contains(msg))
+    }
+
+    checkXmlOptionErrorMessage(Map.empty, "'rowTag' option is required.")
+    checkXmlOptionErrorMessage(Map("rowTag" -> ""),
+      "'rowTag' option should not be an empty string.")
+    checkXmlOptionErrorMessage(Map("rowTag" -> " "),
+      "'rowTag' option should not be an empty string.")
+    checkXmlOptionErrorMessage(Map("rowTag" -> "person",
+      "declaration" -> s"<${XmlOptions.DEFAULT_DECLARATION}>"),
+      "'declaration' should not include angle brackets")
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala
index 176cfd98563..1798d32d8a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/parsers/StaxXmlGeneratorSuite.scala
@@ -69,9 +69,9 @@ final class StaxXmlGeneratorSuite extends SharedSparkSession {
     val df = dataset.toDF().orderBy("booleanDatum")
     val targetFile =
       Files.createTempDirectory("StaxXmlGeneratorSuite").resolve("roundtrip.xml").toString
-    df.write.format("xml").save(targetFile)
+    df.write.option("rowTag", "ROW").xml(targetFile)
     val newDf =
-      spark.read.schema(df.schema).format("xml").load(targetFile).orderBy("booleanDatum")
+      spark.read.option("rowTag", "ROW").schema(df.schema).xml(targetFile).orderBy("booleanDatum")
     assert(df.collect().toSeq === newDf.collect().toSeq)
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org