You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2024/03/20 01:16:55 UTC

(spark) branch master updated: [SPARK-47309][SQL] XML: Add schema inference tests for value tags

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b9c0e9331935 [SPARK-47309][SQL] XML: Add schema inference tests for value tags
b9c0e9331935 is described below

commit b9c0e93319350caa4beecdcd42051449ec1f9c08
Author: Shujing Yang <sh...@databricks.com>
AuthorDate: Wed Mar 20 10:16:43 2024 +0900

    [SPARK-47309][SQL] XML: Add schema inference tests for value tags
    
    ### What changes were proposed in this pull request?
    
    Add schema inference tests for corrupt records, null values and value tags. For value tags, this PR adds the following tests:
    1. Conflict between primitive types
    2. Root-level value tags
    3. Empty value tags in some rows
    4. Arrays of value tags:
       1) values split across multiple lines
       2) interspersed in nested structs: empty fields and optional fields in structs
       3) interspersed in arrays and value tags: empty fields and optional fields in structs
       4) name conflict
       5) CDATA and comments
       6) no spaces / some spaces / whitespace between value tags and elements
    
    ### Why are the changes needed?
    
    To improve test coverage of XML schema inference, in particular for value tags. This is a test-only change.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    This is a test-only change; it adds new test cases to XmlInferSchemaSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45538 from shujingyang-db/xml-inference-test.
    
    Lead-authored-by: Shujing Yang <sh...@databricks.com>
    Co-authored-by: Shujing Yang <13...@users.noreply.github.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../spark/sql/catalyst/xml/XmlInferSchema.scala    |   8 +-
 .../execution/datasources/xml/TestXmlData.scala    | 269 ++++++++++++++++
 .../datasources/xml/XmlInferSchemaSuite.scala      | 338 ++++++++++++++++++++-
 3 files changed, 613 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
index b9342c53d020..4640f86d5997 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/XmlInferSchema.scala
@@ -37,7 +37,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion
 import org.apache.spark.sql.catalyst.expressions.ExprUtils
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DropMalformedMode, FailFastMode, ParseMode, PermissiveMode, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.LegacyDateFormats.FAST_DATE_FORMAT
-import org.apache.spark.sql.catalyst.xml.XmlInferSchema.compatibleType
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.types._
@@ -46,6 +45,8 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
     extends Serializable
     with Logging {
 
+  import org.apache.spark.sql.catalyst.xml.XmlInferSchema._
+
   private val decimalParser = ExprUtils.getDecimalParser(options.locale)
 
   private val timestampFormatter = TimestampFormatter(
@@ -120,6 +121,7 @@ class XmlInferSchema(options: XmlOptions, caseSensitive: Boolean)
       case Some(st: StructType) => st
       case _ =>
         // canonicalizeType erases all empty structs, including the only one we want to keep
+        // XML shouldn't run into this line
         StructType(Seq())
     }
   }
@@ -541,6 +543,10 @@ object XmlInferSchema {
         // As this library can infer an element with attributes as StructType whereas
         // some can be inferred as other non-structural data types, this case should be
         // treated.
+        // 1. Without value tags, combining structs and primitive types defaults to string type
+        // 2. With value tags, combining structs and primitive types defaults to
+        //    a struct with value tags of compatible type
+        // This behavior is aligned with JSON
         case (st: StructType, dt: DataType) if st.fieldNames.contains(valueTag) =>
           val valueIndex = st.fieldNames.indexOf(valueTag)
           val valueField = st.fields(valueIndex)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala
index 704a02482ada..616ccda62fc2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala
@@ -361,6 +361,52 @@ private[xml] trait TestXmlData {
           |</ROW>
           |""".stripMargin :: Nil
 
+  def nullsInArrays: Seq[String] =
+    """<ROW>
+         <field1>
+            <array1>
+              <array2>value1</array2>
+              <array2>value2</array2>
+            </array1>
+            <array1/>
+          </field1>
+          <field1/>
+        </ROW>""" ::
+    """
+        <ROW>
+          <field2/>
+          <field2>
+            <array1>
+              <Test>1</Test>
+            </array1>
+            <array1/>
+          </field2>
+        </ROW>""" ::
+    """
+        <ROW>
+          <field1/>
+          <field1><array1/></field1>
+          <field2/>
+        </ROW>""" :: Nil
+
+  def corruptRecords: Seq[String] =
+    """<ROW>""" ::
+    """""" ::
+    """<ROW>
+        |  <a>1</a>
+        |  <b>2</b>
+        |</ROW>""".stripMargin ::
+    """
+        |<ROW>
+        |  <a>str_a_4</a>
+        |  <b>str_b_4</b>
+        |  <c>str_c_4</c>
+        |</ROW>
+        |""".stripMargin ::
+    """
+        |</ROW>
+        |""".stripMargin :: Nil
+
   def emptyRecords: Seq[String] =
     """<ROW>
           <a><struct></struct></a>
@@ -378,4 +424,227 @@ private[xml] trait TestXmlData {
             <item/>
           </b>
         </ROW>""" :: Nil
+
+  def arrayAndStructRecords: Seq[String] =
+    """<ROW>
+          <a>
+            <b>1</b>
+          </a>
+        </ROW>""" ::
+    """<ROW>
+          <a><item/><item/></a>
+        </ROW>""" ::
+    Nil
+
+  def valueTagsTypeConflict: Seq[String] =
+    """
+      |<ROW>
+      |    13.1
+      |    <a>
+      |        11
+      |        <b>
+      |            true
+      |            <c>1</c>
+      |        </b>
+      |    </a>
+      |    string
+      |</ROW>
+      |""".stripMargin ::
+    """
+      |<ROW>
+      |    string
+      |    <a>
+      |        21474836470
+      |        <b>
+      |            false
+      |            <c>2</c>
+      |        </b>
+      |    </a>
+      |    true
+      |</ROW>
+      |""".stripMargin ::
+    """
+        |<ROW>
+        |<a>
+        |    <b>
+        |        12
+        |        <c>3</c>
+        |    </b>
+        |</a>
+        |92233720368547758070
+        |</ROW>
+        |""".stripMargin :: Nil
+
+  val emptyValueTags: Seq[String] =
+    """
+      |<ROW>
+      |    str1
+      |    <a>  <b>1</b>
+      |    </a>str2
+      |</ROW>
+      |""".stripMargin ::
+    """<ROW> <a><b/> value</a></ROW>""" ::
+    """<ROW><a><b>3</b> </a> </ROW>""" ::
+    """<ROW><a><b>4</b> </a>
+      |    str3
+      |</ROW>""".stripMargin :: Nil
+
+  val multilineValueTags =
+    """
+      |<ROW>
+      |    value1
+      |    <a>1</a>
+      |    value2
+      |</ROW>
+      |""".stripMargin ::
+    """
+      |<ROW>
+      |    value3
+      |    value4<a>1</a>
+      |</ROW>
+      |""".stripMargin :: Nil
+
+  val valueTagsAroundStructs =
+    """
+      |<ROW>
+      |    value1
+      |    <a>
+      |        value2
+      |        <b>
+      |            3
+      |            <c>1</c>
+      |        </b>
+      |        value4
+      |    </a>
+      |    value5
+      |</ROW>
+      |""".stripMargin ::
+    """
+      |<ROW>
+      |    value1
+      |    <a>
+      |        value2
+      |        <b>3</b>
+      |        value4
+      |    </a>
+      |</ROW>
+      |""".stripMargin ::
+  """
+      |<ROW>
+      |    <a>
+      |        <b></b>
+      |        value4
+      |        <!--First comment-->
+      |        value5
+      |    </a>
+      |    value6
+      |</ROW>
+      |""".stripMargin ::
+    """
+      |<ROW>
+      |    value1
+      |    <a>
+      |        value2
+      |        <b>
+      |            3
+      |            <c/>
+      |        </b>
+      |        value4
+      |    </a>
+      |    value5
+      |</ROW>
+      |""".stripMargin :: Nil
+
+  val valueTagsAroundArrays =
+    """
+      |<ROW>
+      |    value1
+      |    <array1>
+      |        value2
+      |        <array2>
+      |          1
+      |          <num>1</num>
+      |          2
+      |        </array2>
+      |        value3
+      |        <!--First comment--> <!--Second comment-->
+      |        value4<!--Third comment-->
+      |        value5
+      |        <array2>2</array2>value6
+      |        value7
+      |    </array1>
+      |    value8
+      |    <array1>
+      |        value9
+      |        <array2> <!--First comment--><num>2</num></array2>
+      |        value10
+      |        <array2></array2>
+      |        <array2> <!--First comment-->
+      |        <!--Second comment--></array2>
+      |        <array2>3</array2>
+      |        value11
+      |    </array1>
+      |    value12
+      |    <!--First comment-->
+      |    value13
+      |</ROW>
+      |""".stripMargin ::
+    """
+      |<ROW>
+      |    <array1>
+      |        value1
+      |    </array1>
+      |</ROW>
+      |""".stripMargin ::
+    """
+      |<ROW>
+      |    <array1>
+      |        <array2>
+      |            1
+      |        </array2>
+      |    </array1>
+      |    value1
+      |</ROW>
+      |""".stripMargin :: Nil
+
+  val valueTagConflictName =
+    """<ROW>
+      |    <a>1</a>
+      |    2
+      |</ROW>""".stripMargin :: Nil
+
+  val valueTagWithComments =
+    """
+      |<ROW>
+      |    <!--First comment-->
+      |    <!--Second comment-->
+      |    <a><!--First comment--></a>
+      |    <a attr="1"><!--First comment--> <!--Second comment--></a>
+      |    2
+      |</ROW>
+      |""".stripMargin :: Nil
+
+  val valueTagWithCDATA =
+    """
+      |<ROW>
+      |    <![CDATA[This is a CDATA section containing <sample1> text.]]>
+      |    <a>
+      |        <![CDATA[This is a CDATA section containing <sample2> text.]]>
+      |        <![CDATA[This is a CDATA section containing <sample3> text.]]>
+      |        <b>1</b>
+      |        <![CDATA[This is a CDATA section containing <sample4> text.]]>
+      |
+      |        <b>2</b>
+      |
+      |    </a>
+      |
+      |</ROW>
+      |""".stripMargin :: Nil
+
+  val valueTagIsNullValue =
+    """
+      |<ROW>
+      |    1
+      |</ROW>
+      |""".stripMargin :: Nil
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala
index 697bd3d8b824..286120ff40b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala
@@ -16,13 +16,19 @@
  */
 package org.apache.spark.sql.execution.datasources.xml
 
+import java.io.File
+
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Encoders, QueryTest, Row}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{
   ArrayType,
   BooleanType,
   DecimalType,
   DoubleType,
+  IntegerType,
   LongType,
   StringType,
   StructField,
@@ -31,7 +37,13 @@ import org.apache.spark.sql.types.{
 
 class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXmlData {
 
-  val baseOptions = Map("rowTag" -> "ROW")
+  private val baseOptions = Map("rowTag" -> "ROW")
+
+  private val ignoreSurroundingSpacesOptions = Map("ignoreSurroundingSpaces" -> "true")
+
+  private val notIgnoreSurroundingSpacesOptions = Map("ignoreSurroundingSpaces" -> "false")
+
+  private val valueTagName = "_VALUE"
 
   def readData(xmlString: Seq[String], options: Map[String, String] = Map.empty): DataFrame = {
     val dataset = spark.createDataset(spark.sparkContext.parallelize(xmlString))(Encoders.STRING)
@@ -293,4 +305,328 @@ class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXml
     assert(emptyDF.schema === expectedSchema)
   }
 
+  test("nulls in arrays") {
+    val expectedSchema = StructType(
+      StructField(
+        "field1",
+        ArrayType(
+          new StructType()
+            .add("array1", ArrayType(new StructType().add("array2", ArrayType(StringType))))
+        )
+      ) ::
+      StructField(
+        "field2",
+        ArrayType(
+          new StructType()
+            .add("array1", ArrayType(StructType(StructField("Test", LongType) :: Nil)))
+        )
+      ) :: Nil
+    )
+    val expectedAns = Seq(
+      Row(Seq(Row(Seq(Row(Seq("value1", "value2")), Row(null))), Row(null)), null),
+      Row(null, Seq(Row(null), Row(Seq(Row(1), Row(null))))),
+      Row(Seq(Row(null), Row(Seq(Row(null)))), Seq(Row(null)))
+    )
+    val xmlDF = readData(nullsInArrays)
+    assert(xmlDF.schema === expectedSchema)
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("corrupt records: fail fast mode") {
+    // fail fast mode is covered in the testcase: DSL test for failing fast in XmlSuite
+    val schemaOne = StructType(
+      StructField("a", StringType, true) ::
+      StructField("b", StringType, true) ::
+      StructField("c", StringType, true) :: Nil
+    )
+    // `DROPMALFORMED` mode should skip corrupt records
+    val xmlDFOne = readData(corruptRecords, Map("mode" -> "DROPMALFORMED"))
+    checkAnswer(
+      xmlDFOne,
+      Row("1", "2", null) ::
+      Row("str_a_4", "str_b_4", "str_c_4") :: Nil
+    )
+    assert(xmlDFOne.schema === schemaOne)
+  }
+
+  test("turn non-nullable schema into a nullable schema") {
+    // XML field is missing.
+    val missingFieldInput = """<ROW><c1>1</c1></ROW>"""
+    val missingFieldInputDS =
+      spark.createDataset(spark.sparkContext.parallelize(missingFieldInput :: Nil))(Encoders.STRING)
+    // XML field is null.
+    val nullValueInput = """<ROW><c1>1</c1><c2/></ROW>"""
+    val nullValueInputDS =
+      spark.createDataset(spark.sparkContext.parallelize(nullValueInput :: Nil))(Encoders.STRING)
+
+    val schema = StructType(
+      Seq(
+        StructField("c1", IntegerType, nullable = false),
+        StructField("c2", IntegerType, nullable = false)
+      )
+    )
+    val expected = schema.asNullable
+
+    Seq(missingFieldInputDS, nullValueInputDS).foreach { xmlStringDS =>
+      Seq("DROPMALFORMED", "FAILFAST", "PERMISSIVE").foreach { mode =>
+        val df = spark.read
+          .option("mode", mode)
+          .option("rowTag", "ROW")
+          .schema(schema)
+          .xml(xmlStringDS)
+        assert(df.schema == expected)
+        checkAnswer(df, Row(1, null) :: Nil)
+      }
+      withSQLConf(SQLConf.LEGACY_RESPECT_NULLABILITY_IN_TEXT_DATASET_CONVERSION.key -> "true") {
+        checkAnswer(
+          spark.read
+            .schema(
+              StructType(
+                StructField("c1", LongType, nullable = false) ::
+                StructField("c2", LongType, nullable = false) :: Nil
+              )
+            )
+            .option("rowTag", "ROW")
+            .option("mode", "DROPMALFORMED")
+            .xml(xmlStringDS),
+          // It is for testing legacy configuration. This is technically a bug as
+          // `0` has to be `null` but the schema is non-nullable.
+          Row(1, 0)
+        )
+      }
+    }
+  }
+
+  test("XML with partitions") {
+    def makePartition(rdd: RDD[String], parent: File, partName: String, partValue: Any): File = {
+      val p = new File(parent, s"$partName=${partValue.toString}")
+      rdd.saveAsTextFile(p.getCanonicalPath)
+      p
+    }
+
+    withTempPath(root => {
+      withTempView("test_myxml_with_part") {
+        val d1 = new File(root, "d1=1")
+        // root/d1=1/col1=abc
+        makePartition(
+          sparkContext.parallelize(2 to 5).map(i => s"""<ROW><a>1</a><b>str$i</b></ROW>"""),
+          d1,
+          "col1",
+          "abc"
+        )
+
+        // root/d1=1/col1=abd
+        makePartition(
+          sparkContext.parallelize(6 to 10).map(i => s"""<ROW><a>1</a><c>str$i</c></ROW>"""),
+          d1,
+          "col1",
+          "abd"
+        )
+        val expectedSchema = new StructType()
+          .add("a", LongType)
+          .add("b", StringType)
+          .add("c", StringType)
+          .add("d1", IntegerType)
+          .add("col1", StringType)
+
+        val df = spark.read.option("rowTag", "ROW").xml(root.getAbsolutePath)
+        assert(df.schema === expectedSchema)
+        assert(df.where(col("d1") === 1).where(col("col1") === "abc").select("a").count() == 4)
+        assert(df.where(col("d1") === 1).where(col("col1") === "abd").select("a").count() == 5)
+        assert(df.where(col("d1") === 1).select("a").count() == 9)
+      }
+    })
+  }
+
+  test("value tag - type conflict and root level value tags") {
+    val xmlDF = readData(valueTagsTypeConflict, ignoreSurroundingSpacesOptions)
+    val expectedSchema = new StructType()
+      .add(valueTagName, ArrayType(StringType))
+      .add(
+        "a",
+        new StructType()
+          .add(valueTagName, LongType)
+          .add("b", new StructType().add(valueTagName, StringType).add("c", LongType))
+      )
+    assert(xmlDF.schema == expectedSchema)
+    val expectedAns = Seq(
+      Row(Seq("13.1", "string"), Row(11, Row("true", 1))),
+      Row(Seq("string", "true"), Row(21474836470L, Row("false", 2))),
+      Row(Seq("92233720368547758070"), Row(null, Row("12", 3)))
+    )
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("value tag - spaces and empty values") {
+    val expectedSchema = new StructType()
+      .add(valueTagName, ArrayType(StringType))
+      .add("a", new StructType().add(valueTagName, StringType).add("b", LongType))
+    // even though we don't ignore the surrounding spaces of characters,
+    // we won't put whitespaces as value tags :)
+    val xmlDFWSpaces =
+      readData(emptyValueTags, notIgnoreSurroundingSpacesOptions)
+    val xmlDFWOSpaces = readData(emptyValueTags, ignoreSurroundingSpacesOptions)
+    assert(xmlDFWSpaces.schema == expectedSchema)
+    assert(xmlDFWOSpaces.schema == expectedSchema)
+
+    val expectedAnsWSpaces = Seq(
+      Row(Seq("\n    str1\n    ", "str2\n"), Row(null, 1)),
+      Row(null, Row(" value", null)),
+      Row(null, Row(null, 3)),
+      Row(Seq("\n    str3\n"), Row(null, 4))
+    )
+    checkAnswer(xmlDFWSpaces, expectedAnsWSpaces)
+    val expectedAnsWOSpaces = Seq(
+      Row(Seq("str1", "str2"), Row(null, 1)),
+      Row(null, Row("value", null)),
+      Row(null, Row(null, 3)),
+      Row(Seq("str3"), Row(null, 4))
+    )
+    checkAnswer(xmlDFWOSpaces, expectedAnsWOSpaces)
+  }
+
+  test("value tags - multiple lines") {
+    val xmlDF = readData(multilineValueTags, ignoreSurroundingSpacesOptions)
+    val expectedSchema =
+      new StructType().add(valueTagName, ArrayType(StringType)).add("a", LongType)
+    val expectedAns = Seq(
+      Row(Seq("value1", "value2"), 1),
+      Row(Seq("value3\n    value4"), 1)
+    )
+    assert(xmlDF.schema == expectedSchema)
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("value tags - around structs") {
+    val xmlDF = readData(valueTagsAroundStructs)
+    val expectedSchema = new StructType()
+      .add(valueTagName, ArrayType(StringType))
+      .add(
+        "a",
+        new StructType()
+          .add(valueTagName, ArrayType(StringType))
+          .add("b", new StructType().add(valueTagName, LongType).add("c", LongType))
+      )
+
+    assert(xmlDF.schema == expectedSchema)
+    val expectedAns = Seq(
+      Row(
+        Seq("value1", "value5"),
+        Row(Seq("value2", "value4"), Row(3, 1))
+      ),
+      Row(
+        Seq("value6"),
+        Row(Seq("value4", "value5"), Row(null, null))
+      ),
+      Row(
+        Seq("value1", "value5"),
+        Row(Seq("value2", "value4"), Row(3, null))
+      ),
+      Row(
+        Seq("value1"),
+        Row(Seq("value2", "value4"), Row(3, null))
+      )
+    )
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("value tags - around arrays") {
+    val xmlDF = readData(valueTagsAroundArrays)
+    val expectedSchema = new StructType()
+      .add(valueTagName, ArrayType(StringType))
+      .add(
+        "array1",
+        ArrayType(
+          new StructType()
+            .add(valueTagName, ArrayType(StringType))
+            .add(
+              "array2",
+              ArrayType(new StructType()
+                // The value tag is not of long type due to:
+                // 1. when we infer the type for the array2 in the second array1,
+                // it combines a struct type and a primitive type and results in a string type
+                // 2. when we merge the inferred type for the first array2 and the second,
+                // we are merging a struct with longtype value tag and a string type.
+                // It results in merging the long type value tag with the primitive type
+                // and thus finally we got a struct with string type value tag.
+                .add(valueTagName, ArrayType(StringType))
+                .add("num", LongType)))))
+    assert(xmlDF.schema === expectedSchema)
+    val expectedAns = Seq(
+      Row(
+        Seq("value1", "value8", "value12", "value13"),
+        Seq(
+          Row(
+            Seq("value2", "value3", "value4", "value5", "value6\n        value7"),
+            Seq(Row(Seq("1", "2"), 1), Row(Seq("2"), null))),
+          Row(
+            Seq("value9", "value10", "value11"),
+            Seq(Row(null, 2), Row(null, null), Row(null, null), Row(Seq("3"), null))))),
+      Row(
+        null,
+        Seq(
+          Row(
+            Seq("value1"), null))),
+      Row(
+        Seq("value1"),
+        Seq(
+          Row(
+            null,
+            Seq(Row(Seq("1"), null))))))
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("value tag - user specifies a conflicting name for valueTag") {
+    val xmlDF = readData(valueTagConflictName, Map("valueTag" -> "a"))
+    val expectedSchema = new StructType().add("a", ArrayType(LongType))
+    assert(xmlDF.schema == expectedSchema)
+    checkAnswer(xmlDF, Seq(Row(Seq(1, 2))))
+  }
+
+  test("value tag - comments") {
+    val xmlDF = readData(valueTagWithComments)
+    val expectedSchema = new StructType()
+      .add(valueTagName, LongType)
+      .add("a", ArrayType(new StructType().add("_attr", LongType)))
+    val expectedAns = Seq(
+      Row(2, Seq(Row(null), Row(1))))
+    assert(xmlDF.schema === expectedSchema)
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("value tags - CDATA") {
+    val xmlDF = readData(valueTagWithCDATA)
+    val expectedSchema = new StructType()
+      .add(valueTagName, StringType)
+      .add("a", new StructType()
+        .add(valueTagName, ArrayType(StringType))
+        .add("b", ArrayType(LongType)))
+
+    val expectedAns = Seq(
+      Row(
+        "This is a CDATA section containing <sample1> text.",
+        Row(
+          Seq(
+            "This is a CDATA section containing <sample2> text.\n" +
+              "        This is a CDATA section containing <sample3> text.",
+            "This is a CDATA section containing <sample4> text."
+          ),
+          Seq(1, 2)
+        )
+      )
+    )
+    assert(xmlDF.schema === expectedSchema)
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("value tag - equals to null value") {
+    // we don't consider options.nullValue during schema inference
+    val xmlDF = readData(valueTagIsNullValue, Map("nullValue" -> "1"))
+    val expectedSchema = new StructType()
+      .add(valueTagName, LongType)
+    val expectedAns = Seq(Row(null))
+    // nullValue option is used during parsing
+    assert(xmlDF.schema === expectedSchema)
+    checkAnswer(xmlDF, expectedAns)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org