You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2024/03/12 00:24:47 UTC
(spark) branch master updated: [SPARK-47309][SQL][XML] Add schema inference unit tests
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 63c01c8036e8 [SPARK-47309][SQL][XML] Add schema inference unit tests
63c01c8036e8 is described below
commit 63c01c8036e8eb672bcca8bc2df994882a1b1727
Author: Shujing Yang <sh...@databricks.com>
AuthorDate: Tue Mar 12 09:24:36 2024 +0900
[SPARK-47309][SQL][XML] Add schema inference unit tests
### What changes were proposed in this pull request?
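As titled: add unit tests for XML schema inference (a new `XmlInferSchemaSuite` plus the XML fixtures it reads from `TestXmlData`).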
### Why are the changes needed?
Fix a bug.
### Does this PR introduce _any_ user-facing change?
Yes
### How was this patch tested?
Unit tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #45411 from shujingyang-db/xml-inference-check.
Authored-by: Shujing Yang <sh...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../execution/datasources/xml/TestXmlData.scala | 310 +++++++++++++++++++++
.../datasources/xml/XmlInferSchemaSuite.scala | 296 ++++++++++++++++++++
2 files changed, 606 insertions(+)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala
index abcf8c7cdd72..704a02482ada 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/TestXmlData.scala
@@ -68,4 +68,314 @@ private[xml] trait TestXmlData {
f(dir)
fs.setVerifyChecksum(true)
}
+
+ // Four <ROW> records whose primitive fields mix value kinds across rows
+ // (integers, decimals, booleans, strings and empty elements), so schema
+ // inference has to reconcile conflicting types per field.
+ // Read by "Type conflict in primitive field values" in XmlInferSchemaSuite.
+ def primitiveFieldValueTypeConflict: Seq[String] =
+ """<ROW>
+ | <num_num_1>11</num_num_1>
+ | <num_num_2/>
+ | <num_num_3>1.1</num_num_3>
+ | <num_bool>true</num_bool>
+ | <num_str>13.1</num_str>
+ | <str_bool>str1</str_bool>
+ |</ROW>
+ |""".stripMargin ::
+ """
+ |<ROW>
+ | <num_num_1/>
+ | <num_num_2>21474836470.9</num_num_2>
+ | <num_num_3/>
+ | <num_bool>12</num_bool>
+ | <num_str/>
+ | <str_bool>true</str_bool>
+ |</ROW>""".stripMargin ::
+ """
+ |<ROW>
+ | <num_num_1>21474836470</num_num_1>
+ | <num_num_2>92233720368547758070</num_num_2>
+ | <num_num_3>100</num_num_3>
+ | <num_bool>false</num_bool>
+ | <num_str>str1</num_str>
+ | <str_bool>false</str_bool>
+ |</ROW>""".stripMargin ::
+ """
+ |<ROW>
+ | <num_num_1>21474836570</num_num_1>
+ | <num_num_2>1.1</num_num_2>
+ | <num_num_3>21474836470</num_num_3>
+ | <num_bool/>
+ | <num_str>92233720368547758070</num_str>
+ | <str_bool/>
+ |</ROW>""".stripMargin :: Nil
+
+ // Four records sharing the same `ip` and an empty `nullstr`. `headers` is a
+ // populated struct (Host, Charset) only in the first record; the others carry
+ // it as `<headers/>` or `<headers></headers>`.
+ // Read by "Complex field and type inferring with null in sampling".
+ def xmlNullStruct: Seq[String] =
+ """<ROW>
+ | <nullstr></nullstr>
+ | <ip>27.31.100.29</ip>
+ | <headers>
+ | <Host>1.abc.com</Host>
+ | <Charset>UTF-8</Charset>
+ | </headers>
+ |</ROW>""".stripMargin ::
+ """<ROW>
+ | <nullstr></nullstr>
+ | <ip>27.31.100.29</ip>
+ | <headers/>
+ |</ROW>""".stripMargin ::
+ """<ROW>
+ | <nullstr></nullstr>
+ | <ip>27.31.100.29</ip>
+ | <headers></headers>
+ |</ROW>""".stripMargin ::
+ """<ROW>
+ | <nullstr/>
+ | <ip>27.31.100.29</ip>
+ | <headers/>
+ |</ROW>""".stripMargin :: Nil
+
+ // Records where the same field is a scalar in one row, a struct or a repeated
+ // (array-like) element in another, and empty elsewhere.
+ // NOTE: these literals are not passed through stripMargin, so element text keeps
+ // the source indentation; the consuming test ("Type conflict in complex field
+ // values") enables ignoreSurroundingSpaces to compensate.
+ def complexFieldValueTypeConflict: Seq[String] =
+ """<ROW>
+ <num_struct>11</num_struct>
+ <str_array>1</str_array>
+ <str_array>2</str_array>
+ <str_array>3</str_array>
+ <array></array>
+ <struct_array></struct_array>
+ <struct></struct>
+ </ROW>""" ::
+ """<ROW>
+ <num_struct>
+ <field>false</field>
+ </num_struct>
+ <str_array/>
+ <array/>
+ <struct_array></struct_array>
+ <struct/>
+ </ROW>""" ::
+ """<ROW>
+ <num_struct/>
+ <str_array>str</str_array>
+ <array>4</array>
+ <array>5</array>
+ <array>6</array>
+ <struct_array>7</struct_array>
+ <struct_array>8</struct_array>
+ <struct_array>9</struct_array>
+ <struct>
+ <field/>
+ </struct>
+ </ROW>""" ::
+ """<ROW>
+ <num_struct></num_struct>
+ <str_array>str1</str_array>
+ <str_array>str2</str_array>
+ <str_array>33</str_array>
+ <array>7</array>
+ <struct_array>
+ <field>true</field>
+ </struct_array>
+ <struct>
+ <field>str</field>
+ </struct>
+ </ROW>""" :: Nil
+
+ // `array1` repeats with `<element>` children that are scalars, empty, or wrap
+ // nested `<array>`/`<object>` elements; `array2` repeats a struct whose `field`
+ // values are integers of different magnitudes; `array3` is a repeated struct in
+ // one record and a repeated scalar in another.
+ // Read by "Type conflict in array elements".
+ def arrayElementTypeConflict: Seq[String] =
+ """
+ |<ROW>
+ | <array1>
+ | <element>1</element>
+ | <element>1.1</element>
+ | <element>true</element>
+ | <element/>
+ | <element>
+ | <array/>
+ | </element>
+ | <element>
+ | <object/>
+ | </element>
+ | </array1>
+ | <array1>
+ | <element>
+ | <array>
+ | <element>2</element>
+ | <element>3</element>
+ | <element>4</element>
+ | </array>
+ | </element>
+ | <element>
+ | <object>
+ | <field>str</field>
+ | </object>
+ | </element>
+ | </array1>
+ | <array2>
+ | <field>214748364700</field>
+ | </array2>
+ | <array2>
+ | <field>1</field>
+ | </array2>
+ |</ROW>
+ |""".stripMargin ::
+ """
+ |<ROW>
+ | <array3>
+ | <field>str</field>
+ | </array3>
+ | <array3>
+ | <field>1</field>
+ | </array3>
+ |</ROW>
+ |""".stripMargin ::
+ """
+ |<ROW>
+ | <array3>1</array3>
+ | <array3>2</array3>
+ | <array3>3</array3>
+ |</ROW>
+ |""".stripMargin :: Nil
+
+ // Five records, each populating exactly one of the fields `a`..`e`, so every
+ // field is missing from all but one record. `c` is repeated within its record.
+ // Read by "Handling missing fields".
+ def missingFields: Seq[String] =
+ """
+ <ROW><a>true</a></ROW>
+ """ ::
+ """
+ <ROW><b>21474836470</b></ROW>
+ """ ::
+ """
+ <ROW><c>33</c><c>44</c></ROW>
+ """ ::
+ """
+ <ROW><d><field>true</field></d></ROW>
+ """ ::
+ """
+ <ROW><e>str</e></ROW>
+ """ :: Nil
+
+ // A single record covering struct fields, repeated scalars of many numeric
+ // widths, repeated empty elements, and repeated structs with varying fields.
+ // XML doesn't support array of arrays directly: a repeated element becomes an
+ // array, so "array of arrays" can only be modelled as repeated elements that
+ // themselves contain repeated children (`arrayOfArray1` / `arrayOfArray2` below),
+ // i.e. an array of structs with an array-typed `item` field.
+ // Read by "Complex field and type inferring".
+ def complexFieldAndType1: Seq[String] =
+ """
+ |<ROW>
+ | <struct>
+ | <field1>true</field1>
+ | <field2>92233720368547758070</field2>
+ | </struct>
+ | <structWithArrayFields>
+ | <field1>4</field1>
+ | <field1>5</field1>
+ | <field1>6</field1>
+ | <field2>str1</field2>
+ | <field2>str2</field2>
+ | </structWithArrayFields>
+ | <arrayOfString>str1</arrayOfString>
+ | <arrayOfString>str2</arrayOfString>
+ | <arrayOfInteger>1</arrayOfInteger>
+ | <arrayOfInteger>2147483647</arrayOfInteger>
+ | <arrayOfInteger>-2147483648</arrayOfInteger>
+ | <arrayOfLong>21474836470</arrayOfLong>
+ | <arrayOfLong>9223372036854775807</arrayOfLong>
+ | <arrayOfLong>-9223372036854775808</arrayOfLong>
+ | <arrayOfBigInteger>922337203685477580700</arrayOfBigInteger>
+ | <arrayOfBigInteger>-922337203685477580800</arrayOfBigInteger>
+ | <arrayOfDouble>1.2</arrayOfDouble>
+ | <arrayOfDouble>1.7976931348623157</arrayOfDouble>
+ | <arrayOfDouble>4.9E-324</arrayOfDouble>
+ | <arrayOfDouble>2.2250738585072014E-308</arrayOfDouble>
+ | <arrayOfBoolean>true</arrayOfBoolean>
+ | <arrayOfBoolean>false</arrayOfBoolean>
+ | <arrayOfBoolean>true</arrayOfBoolean>
+ | <arrayOfNull></arrayOfNull>
+ | <arrayOfNull></arrayOfNull>
+ | <arrayOfStruct>
+ | <field1>true</field1>
+ | <field2>str1</field2>
+ | </arrayOfStruct>
+ | <arrayOfStruct>
+ | <field1>false</field1>
+ | </arrayOfStruct>
+ | <arrayOfStruct>
+ | <field3/>
+ | </arrayOfStruct>
+ |<arrayOfArray1>
+ | <item>1</item><item>2</item><item>3</item>
+ |</arrayOfArray1>
+ |<arrayOfArray1>
+ | <item>str1</item><item>str2</item>
+ |</arrayOfArray1>
+ |<arrayOfArray2>
+ | <item>1</item><item>2</item><item>3</item>
+ |</arrayOfArray2>
+ |<arrayOfArray2>
+ | <item>1.1</item><item>2.1</item><item>3.1</item>
+ |</arrayOfArray2>
+ |</ROW>
+ |
+ |""".stripMargin :: Nil
+
+ // A single record with three levels of repetition: repeated `arrayOfArray1` /
+ // `arrayOfArray2` elements wrap repeated `<array>` elements, which wrap repeated
+ // `<item>` elements. `arrayOfArray2` items carry fields that vary between items
+ // (`inner1`, repeated `inner2`, and a nested, sometimes-empty `inner3`).
+ // Read by "complex arrays".
+ def complexFieldAndType2: Seq[String] =
+ """
+ |<ROW>
+ | <arrayOfArray1>
+ | <array>
+ | <item>5</item>
+ | </array>
+ |</arrayOfArray1>
+ |<arrayOfArray1>
+ | <array>
+ | <item>6</item><item>7</item>
+ | </array>
+ | <array>
+ | <item>8</item>
+ | </array>
+ |</arrayOfArray1>
+ | <arrayOfArray2>
+ | <array>
+ | <item>
+ | <inner1>str1</inner1>
+ | </item>
+ | </array>
+ |</arrayOfArray2>
+ |<arrayOfArray2>
+ | <array/>
+ | <array>
+ | <item>
+ | <inner2>str3</inner2>
+ | <inner2>str33</inner2>
+ | </item>
+ | <item>
+ | <inner2>str4</inner2>
+ | <inner1>str11</inner1>
+ | </item>
+ | </array>
+ |</arrayOfArray2>
+ |<arrayOfArray2>
+ | <array>
+ | <item>
+ | <inner3>
+ | <inner4>2</inner4>
+ | <inner4>3</inner4>
+ | </inner3>
+ | <inner3/>
+ | </item>
+ | </array>
+ |</arrayOfArray2>
+ |</ROW>
+ |""".stripMargin :: Nil
+
+ // Records whose innermost elements are empty (`<struct></struct>`, `<c/>`,
+ // `<item/>`), so nested structs are inferred without any leaf values.
+ // Read by "empty records".
+ def emptyRecords: Seq[String] =
+ """<ROW>
+ <a><struct></struct></a>
+ </ROW>""" ::
+ """<ROW>
+ <a>
+ <struct><b><c/></b></struct>
+ </a>
+ </ROW>""" ::
+ """<ROW>
+ <b>
+ <item>
+ <c><struct></struct></c>
+ </item>
+ <item/>
+ </b>
+ </ROW>""" :: Nil
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala
new file mode 100644
index 000000000000..697bd3d8b824
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlInferSchemaSuite.scala
@@ -0,0 +1,296 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.xml
+
+import org.apache.spark.sql.{DataFrame, Encoders, QueryTest, Row}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{
+ ArrayType,
+ BooleanType,
+ DecimalType,
+ DoubleType,
+ LongType,
+ StringType,
+ StructField,
+ StructType
+}
+
+/**
+ * Schema-inference tests for the XML data source.
+ *
+ * Each test reads a fixture from [[TestXmlData]] through `readData` and asserts the
+ * inferred schema and, where the test checks it, the parsed rows.
+ */
+class XmlInferSchemaSuite extends QueryTest with SharedSparkSession with TestXmlData {
+
+  // Every fixture wraps its records in <ROW> elements.
+  val baseOptions = Map("rowTag" -> "ROW")
+
+  /**
+   * Reads `xmlString` as XML through a `Dataset[String]`, letting Spark infer the schema.
+   *
+   * @param xmlString XML documents, one per dataset element
+   * @param options extra reader options; merged after `baseOptions`, so they win on conflict
+   * @return the DataFrame produced by `spark.read.xml`
+   */
+  def readData(xmlString: Seq[String], options: Map[String, String] = Map.empty): DataFrame = {
+    val dataset = spark.createDataset(spark.sparkContext.parallelize(xmlString))(Encoders.STRING)
+    spark.read.options(baseOptions ++ options).xml(dataset)
+  }
+
+  // TODO: add tests for type widening
+  test("Type conflict in primitive field values") {
+    val xmlDF = readData(primitiveFieldValueTypeConflict, Map("nullValue" -> ""))
+    val expectedSchema = StructType(
+      StructField("num_bool", StringType, true) ::
+        StructField("num_num_1", LongType, true) ::
+        StructField("num_num_2", DoubleType, true) ::
+        StructField("num_num_3", DoubleType, true) ::
+        StructField("num_str", StringType, true) ::
+        StructField("str_bool", StringType, true) :: Nil
+    )
+    // num_num_3 is inferred as DoubleType, so its expected values are written as Doubles.
+    val expectedAns = Row("true", 11L, null, 1.1, "13.1", "str1") ::
+      Row("12", null, 21474836470.9, null, null, "true") ::
+      Row("false", 21474836470L, 92233720368547758070d, 100.0, "str1", "false") ::
+      Row(null, 21474836570L, 1.1, 21474836470.0, "92233720368547758070", null) :: Nil
+    assert(expectedSchema === xmlDF.schema)
+    checkAnswer(xmlDF, expectedAns)
+  }
+
+  test("Type conflict in complex field values") {
+    val xmlDF = readData(
+      complexFieldValueTypeConflict,
+      Map("nullValue" -> "", "ignoreSurroundingSpaces" -> "true")
+    )
+    // XML will merge an array and a singleton into an array
+    val expectedSchema = StructType(
+      StructField("array", ArrayType(LongType, true), true) ::
+        StructField("num_struct", StringType, true) ::
+        StructField("str_array", ArrayType(StringType), true) ::
+        StructField("struct", StructType(StructField("field", StringType, true) :: Nil), true) ::
+        StructField("struct_array", ArrayType(StringType), true) :: Nil
+    )
+
+    assert(expectedSchema === xmlDF.schema)
+    checkAnswer(
+      xmlDF,
+      Row(Seq(null), "11", Seq("1", "2", "3"), Row(null), Seq(null)) ::
+        Row(Seq(null), """<field>false</field>""", Seq(null), Row(null), Seq(null)) ::
+        Row(Seq(4L, 5L, 6L), null, Seq("str"), Row(null), Seq("7", "8", "9")) ::
+        Row(
+          Seq(7L), null, Seq("str1", "str2", "33"), Row("str"), Seq("""<field>true</field>""")) ::
+        Nil
+    )
+  }
+
+  test("Type conflict in array elements") {
+    val xmlDF =
+      readData(
+        arrayElementTypeConflict,
+        Map("ignoreSurroundingSpaces" -> "true", "nullValue" -> ""))
+
+    val expectedSchema = StructType(
+      StructField(
+        "array1",
+        ArrayType(StructType(StructField("element", ArrayType(StringType)) :: Nil), true),
+        true
+      ) ::
+        StructField(
+          "array2",
+          ArrayType(StructType(StructField("field", LongType, true) :: Nil), true),
+          true
+        ) ::
+        StructField("array3", ArrayType(StringType, true), true) :: Nil
+    )
+
+    assert(xmlDF.schema === expectedSchema)
+    checkAnswer(
+      xmlDF,
+      Row(
+        Seq(
+          Row(List("1", "1.1", "true", null, "<array></array>", "<object></object>")),
+          Row(
+            List(
+              """<array>
+                | <element>2</element>
+                | <element>3</element>
+                | <element>4</element>
+                | </array>""".stripMargin,
+              """<object>
+                | <field>str</field>
+                | </object>""".stripMargin
+            )
+          )
+        ),
+        // array2.field is LongType, so both values are written as Longs.
+        Seq(Row(214748364700L), Row(1L)),
+        null
+      ) ::
+        Row(null, null, Seq("""<field>str</field>""", """<field>1</field>""")) ::
+        Row(null, null, Seq("1", "2", "3")) :: Nil
+    )
+  }
+
+  test("Handling missing fields") {
+    val xmlDF = readData(missingFields)
+
+    val expectedSchema = StructType(
+      StructField("a", BooleanType, true) ::
+        StructField("b", LongType, true) ::
+        StructField("c", ArrayType(LongType, true), true) ::
+        StructField("d", StructType(StructField("field", BooleanType, true) :: Nil), true) ::
+        StructField("e", StringType, true) :: Nil
+    )
+
+    assert(expectedSchema === xmlDF.schema)
+    // Each record populates exactly one field; every other column must come back null.
+    checkAnswer(
+      xmlDF,
+      Row(true, null, null, null, null) ::
+        Row(null, 21474836470L, null, null, null) ::
+        Row(null, null, Seq(33L, 44L), null, null) ::
+        Row(null, null, null, Row(true), null) ::
+        Row(null, null, null, null, "str") :: Nil
+    )
+  }
+
+  test("Complex field and type inferring") {
+    val xmlDF = readData(complexFieldAndType1, Map("prefersDecimal" -> "true"))
+    val expectedSchema = StructType(
+      StructField(
+        "arrayOfArray1",
+        ArrayType(StructType(StructField("item", ArrayType(StringType, true)) :: Nil)),
+        true
+      ) ::
+        StructField(
+          "arrayOfArray2",
+          ArrayType(
+            StructType(StructField("item", ArrayType(DecimalType(21, 1), true)) :: Nil),
+            true),
+          true
+        ) ::
+        StructField("arrayOfBigInteger", ArrayType(DecimalType(21, 0), true), true) ::
+        StructField("arrayOfBoolean", ArrayType(BooleanType, true), true) ::
+        StructField("arrayOfDouble", ArrayType(DoubleType, true), true) ::
+        StructField("arrayOfInteger", ArrayType(LongType, true), true) ::
+        StructField("arrayOfLong", ArrayType(DecimalType(20, 0), true), true) ::
+        StructField("arrayOfNull", ArrayType(StringType, true), true) ::
+        StructField("arrayOfString", ArrayType(StringType, true), true) ::
+        StructField(
+          "arrayOfStruct",
+          ArrayType(
+            StructType(
+              StructField("field1", BooleanType, true) ::
+                StructField("field2", StringType, true) ::
+                StructField("field3", StringType, true) :: Nil
+            ),
+            true
+          ),
+          true
+        ) ::
+        StructField(
+          "struct",
+          StructType(
+            StructField("field1", BooleanType, true) ::
+              StructField("field2", DecimalType(20, 0), true) :: Nil
+          ),
+          true
+        ) ::
+        StructField(
+          "structWithArrayFields",
+          StructType(
+            StructField("field1", ArrayType(LongType, true), true) ::
+              StructField("field2", ArrayType(StringType, true), true) :: Nil
+          ),
+          true
+        ) :: Nil
+    )
+    assert(expectedSchema === xmlDF.schema)
+  }
+
+  test("complex arrays") {
+    val xmlDF = readData(complexFieldAndType2)
+    val expectedSchemaArrayOfArray1 = new StructType().add(
+      "arrayOfArray1",
+      ArrayType(
+        new StructType()
+          .add("array", ArrayType(new StructType().add("item", ArrayType(LongType))))
+      )
+    )
+    assert(xmlDF.select("arrayOfArray1").schema === expectedSchemaArrayOfArray1)
+    // `item` is LongType, so the expected values are written as Longs.
+    checkAnswer(
+      xmlDF.select("arrayOfArray1"),
+      Row(
+        Seq(
+          Row(Seq(Row(Seq(5L)))),
+          Row(Seq(Row(Seq(6L, 7L)), Row(Seq(8L))))
+        )
+      ) :: Nil
+    )
+    val expectedSchemaArrayOfArray2 = new StructType().add(
+      "arrayOfArray2",
+      ArrayType(
+        new StructType()
+          .add(
+            "array",
+            ArrayType(
+              new StructType().add(
+                "item",
+                ArrayType(
+                  new StructType()
+                    .add("inner1", StringType)
+                    .add("inner2", ArrayType(StringType))
+                    .add("inner3", ArrayType(new StructType().add("inner4", ArrayType(LongType))))
+                )
+              )
+            )
+          )
+      )
+    )
+    assert(xmlDF.select("arrayOfArray2").schema === expectedSchemaArrayOfArray2)
+    checkAnswer(
+      xmlDF.select("arrayOfArray2"),
+      Row(
+        Seq(
+          Row(Seq(Row(Seq(Row("str1", null, null))))),
+          Row(
+            Seq(
+              Row(null),
+              Row(Seq(Row(null, Seq("str3", "str33"), null), Row("str11", Seq("str4"), null)))
+            )
+          ),
+          Row(Seq(Row(Seq(Row(null, null, Seq(Row(Seq(2L, 3L)), Row(null)))))))
+        )
+      ) :: Nil
+    )
+  }
+
+  test("Complex field and type inferring with null in sampling") {
+    val xmlDF = readData(xmlNullStruct)
+    val expectedSchema = StructType(
+      StructField(
+        "headers",
+        StructType(
+          StructField("Charset", StringType, true) ::
+            StructField("Host", StringType, true) :: Nil
+        ),
+        true
+      ) ::
+        StructField("ip", StringType, true) ::
+        StructField("nullstr", StringType, true) :: Nil
+    )
+
+    assert(expectedSchema === xmlDF.schema)
+    // Select by the exact inferred field name rather than relying on case-insensitive
+    // column resolution.
+    checkAnswer(
+      xmlDF.select("nullstr", "headers.Host"),
+      Seq(Row("", "1.abc.com"), Row("", null), Row("", null), Row("", null))
+    )
+  }
+
+  test("empty records") {
+    val xmlDF = readData(emptyRecords)
+    val expectedSchema = new StructType()
+      .add(
+        "a",
+        new StructType()
+          .add(
+            "struct",
+            StructType(StructField("b", StructType(StructField("c", StringType) :: Nil)) :: Nil)))
+      .add(
+        "b",
+        new StructType()
+          .add(
+            "item",
+            ArrayType(
+              new StructType().add("c", StructType(StructField("struct", StringType) :: Nil)))))
+    assert(xmlDF.schema === expectedSchema)
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org