You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2019/03/26 22:58:48 UTC
[spark] branch master updated: [SPARK-27269][SQL] File source v2
should validate data schema only
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 267160b [SPARK-27269][SQL] File source v2 should validate data schema only
267160b is described below
commit 267160b36032c81b9d89ea4a90d3757a1f1e98b7
Author: Gengliang Wang <ge...@databricks.com>
AuthorDate: Wed Mar 27 07:58:31 2019 +0900
[SPARK-27269][SQL] File source v2 should validate data schema only
## What changes were proposed in this pull request?
Currently, File source v2 allows each data source to specify the supported data types by implementing the method `supportsDataType` in `FileScan` and `FileWriteBuilder`.
However, in the read path, the validation checks all the data types in `readSchema`, which may contain partition columns. This is a regression: for example, the Text data source only supports the String data type, but its partition columns can still be of Integer type, because partition columns are processed by Spark itself rather than by the data source.
This PR is to:
1. Refactor schema validation so that only the data schema is checked.
2. Filter out partition columns from the data schema when a user-specified schema is provided.
## How was this patch tested?
New unit tests in `FileTableSuite`.
Closes #24203 from gengliangwang/schemaValidation.
Authored-by: Gengliang Wang <ge...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../sql/execution/datasources/v2/FileScan.scala | 30 +-------
.../sql/execution/datasources/v2/FileTable.scala | 30 +++++++-
.../datasources/v2/FileWriteBuilder.scala | 22 ++----
.../datasources/v2/csv/CSVDataSourceV2.scala | 12 +---
.../sql/execution/datasources/v2/csv/CSVScan.scala | 6 --
.../execution/datasources/v2/csv/CSVTable.scala | 14 +++-
.../datasources/v2/csv/CSVWriteBuilder.scala | 14 ++--
.../datasources/v2/orc/OrcDataSourceV2.scala | 18 +----
.../sql/execution/datasources/v2/orc/OrcScan.scala | 8 +--
.../execution/datasources/v2/orc/OrcTable.scala | 21 +++++-
.../datasources/v2/orc/OrcWriteBuilder.scala | 14 ++--
.../execution/datasources/v2/FileTableSuite.scala | 84 ++++++++++++++++++++++
12 files changed, 166 insertions(+), 107 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index 6ab5c4b..e971fd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -18,11 +18,11 @@ package org.apache.spark.sql.execution.datasources.v2
import org.apache.hadoop.fs.Path
-import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.PartitionedFileUtil
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
abstract class FileScan(
@@ -37,22 +37,6 @@ abstract class FileScan(
false
}
- /**
- * Returns whether this format supports the given [[DataType]] in write path.
- * By default all data types are supported.
- */
- def supportsDataType(dataType: DataType): Boolean = true
-
- /**
- * The string that represents the format that this data source provider uses. This is
- * overridden by children to provide a nice alias for the data source. For example:
- *
- * {{{
- * override def formatName(): String = "ORC"
- * }}}
- */
- def formatName: String
-
protected def partitions: Seq[FilePartition] = {
val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty)
val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions)
@@ -76,13 +60,5 @@ abstract class FileScan(
partitions.toArray
}
- override def toBatch: Batch = {
- readSchema.foreach { field =>
- if (!supportsDataType(field.dataType)) {
- throw new AnalysisException(
- s"$formatName data source does not support ${field.dataType.catalogString} data type.")
- }
- }
- this
- }
+ override def toBatch: Batch = this
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
index 4b35df3..188016c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.sources.v2.TableCapability._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.util.SchemaUtils
@@ -46,7 +46,11 @@ abstract class FileTable(
sparkSession, rootPathsSpecified, caseSensitiveMap, userSpecifiedSchema, fileStatusCache)
}
- lazy val dataSchema: StructType = userSpecifiedSchema.orElse {
+ lazy val dataSchema: StructType = userSpecifiedSchema.map { schema =>
+ val partitionSchema = fileIndex.partitionSchema
+ val resolver = sparkSession.sessionState.conf.resolver
+ StructType(schema.filterNot(f => partitionSchema.exists(p => resolver(p.name, f.name))))
+ }.orElse {
inferSchema(fileIndex.allFiles())
}.getOrElse {
throw new AnalysisException(
@@ -57,6 +61,12 @@ abstract class FileTable(
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
SchemaUtils.checkColumnNameDuplication(dataSchema.fieldNames,
"in the data schema", caseSensitive)
+ dataSchema.foreach { field =>
+ if (!supportsDataType(field.dataType)) {
+ throw new AnalysisException(
+ s"$formatName data source does not support ${field.dataType.catalogString} data type.")
+ }
+ }
val partitionSchema = fileIndex.partitionSchema
SchemaUtils.checkColumnNameDuplication(partitionSchema.fieldNames,
"in the partition schema", caseSensitive)
@@ -72,6 +82,22 @@ abstract class FileTable(
* Spark will require that user specify the schema manually.
*/
def inferSchema(files: Seq[FileStatus]): Option[StructType]
+
+ /**
+ * Returns whether this format supports the given [[DataType]]. It is checked against every
+ * field of the table's data schema and handed to the format's write builder. Default: true.
+ */
+ def supportsDataType(dataType: DataType): Boolean = true
+
+ /**
+ * The user-facing name of this format; e.g. it prefixes the "data source does not support
+ * ... data type" error. Overridden by children to provide a nice alias. For example:
+ *
+ * {{{
+ * override def formatName: String = "ORC"
+ * }}}
+ */
+ def formatName: String
}
object FileTable {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala
index bb4a428..7ff5c41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala
@@ -39,7 +39,11 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.util.SerializableConfiguration
-abstract class FileWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
+abstract class FileWriteBuilder(
+ options: CaseInsensitiveStringMap,
+ paths: Seq[String],
+ formatName: String,
+ supportsDataType: DataType => Boolean)
extends WriteBuilder with SupportsSaveMode {
private var schema: StructType = _
private var queryId: String = _
@@ -108,22 +112,6 @@ abstract class FileWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[St
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory
- /**
- * Returns whether this format supports the given [[DataType]] in write path.
- * By default all data types are supported.
- */
- def supportsDataType(dataType: DataType): Boolean = true
-
- /**
- * The string that represents the format that this data source provider uses. This is
- * overridden by children to provide a nice alias for the data source. For example:
- *
- * {{{
- * override def formatName(): String = "ORC"
- * }}}
- */
- def formatName: String
-
private def validateInputs(caseSensitiveAnalysis: Boolean): Unit = {
assert(schema != null, "Missing input data schema")
assert(queryId != null, "Missing query ID")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala
index 4ecd9cd..55222c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVDataSourceV2.scala
@@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.sources.v2.Table
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
class CSVDataSourceV2 extends FileDataSourceV2 {
@@ -41,13 +41,3 @@ class CSVDataSourceV2 extends FileDataSourceV2 {
CSVTable(tableName, sparkSession, options, paths, Some(schema))
}
}
-
-object CSVDataSourceV2 {
- def supportsDataType(dataType: DataType): Boolean = dataType match {
- case _: AtomicType => true
-
- case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
-
- case _ => false
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
index 35c6a66..8f2f8f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
@@ -75,10 +75,4 @@ case class CSVScan(
CSVPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, fileIndex.partitionSchema, readSchema, parsedOptions)
}
-
- override def supportsDataType(dataType: DataType): Boolean = {
- CSVDataSourceV2.supportsDataType(dataType)
- }
-
- override def formatName: String = "CSV"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
index bf4b8ba..852cbf0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
import org.apache.spark.sql.execution.datasources.v2.FileTable
import org.apache.spark.sql.sources.v2.writer.WriteBuilder
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{AtomicType, DataType, StructType, UserDefinedType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
case class CSVTable(
@@ -48,5 +48,15 @@ case class CSVTable(
}
override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder =
- new CSVWriteBuilder(options, paths)
+ new CSVWriteBuilder(options, paths, formatName, supportsDataType)
+
+ override def supportsDataType(dataType: DataType): Boolean = dataType match {
+ case _: AtomicType => true
+ // A UDT is supported exactly when its underlying SQL type is supported.
+ case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
+ // Non-atomic types such as StructType, ArrayType and MapType are rejected for CSV.
+ case _ => false
+ }
+ // Format alias reported in unsupported-data-type errors and passed to CSVWriteBuilder.
+ override def formatName: String = "CSV"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala
index bb26d2f..92b47e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWriteBuilder.scala
@@ -27,8 +27,12 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
-class CSVWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
- extends FileWriteBuilder(options, paths) {
+class CSVWriteBuilder(
+ options: CaseInsensitiveStringMap,
+ paths: Seq[String],
+ formatName: String,
+ supportsDataType: DataType => Boolean)
+ extends FileWriteBuilder(options, paths, formatName, supportsDataType) {
override def prepareWrite(
sqlConf: SQLConf,
job: Job,
@@ -56,10 +60,4 @@ class CSVWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
}
}
}
-
- override def supportsDataType(dataType: DataType): Boolean = {
- CSVDataSourceV2.supportsDataType(dataType)
- }
-
- override def formatName: String = "CSV"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
index 36e7e12..e8b9e6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala
@@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.sources.v2.Table
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
class OrcDataSourceV2 extends FileDataSourceV2 {
@@ -42,19 +42,3 @@ class OrcDataSourceV2 extends FileDataSourceV2 {
}
}
-object OrcDataSourceV2 {
- def supportsDataType(dataType: DataType): Boolean = dataType match {
- case _: AtomicType => true
-
- case st: StructType => st.forall { f => supportsDataType(f.dataType) }
-
- case ArrayType(elementType, _) => supportsDataType(elementType)
-
- case MapType(keyType, valueType, _) =>
- supportsDataType(keyType) && supportsDataType(valueType)
-
- case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
-
- case _ => false
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index 237eadb..fc8a682 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.SerializableConfiguration
@@ -43,10 +43,4 @@ case class OrcScan(
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
dataSchema, fileIndex.partitionSchema, readSchema)
}
-
- override def supportsDataType(dataType: DataType): Boolean = {
- OrcDataSourceV2.supportsDataType(dataType)
- }
-
- override def formatName: String = "ORC"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
index aac38fb..ace77b7c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.orc.OrcUtils
import org.apache.spark.sql.execution.datasources.v2.FileTable
import org.apache.spark.sql.sources.v2.writer.WriteBuilder
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
case class OrcTable(
@@ -40,5 +40,22 @@ case class OrcTable(
OrcUtils.readSchema(sparkSession, files)
override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder =
- new OrcWriteBuilder(options, paths)
+ new OrcWriteBuilder(options, paths, formatName, supportsDataType)
+
+ override def supportsDataType(dataType: DataType): Boolean = dataType match {
+ case _: AtomicType => true
+ // A struct is supported only if every one of its field types is (checked recursively).
+ case st: StructType => st.forall { f => supportsDataType(f.dataType) }
+ // An array is supported if its element type is.
+ case ArrayType(elementType, _) => supportsDataType(elementType)
+ // A map is supported if both its key type and its value type are.
+ case MapType(keyType, valueType, _) =>
+ supportsDataType(keyType) && supportsDataType(valueType)
+ // A UDT is supported exactly when its underlying SQL type is supported.
+ case udt: UserDefinedType[_] => supportsDataType(udt.sqlType)
+ // Any other data type is rejected.
+ case _ => false
+ }
+ // Format alias reported in unsupported-data-type errors and passed to OrcWriteBuilder.
+ override def formatName: String = "ORC"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala
index 829ab5f..f5b06e1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala
@@ -28,8 +28,12 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
-class OrcWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
- extends FileWriteBuilder(options, paths) {
+class OrcWriteBuilder(
+ options: CaseInsensitiveStringMap,
+ paths: Seq[String],
+ formatName: String,
+ supportsDataType: DataType => Boolean)
+ extends FileWriteBuilder(options, paths, formatName, supportsDataType) {
override def prepareWrite(
sqlConf: SQLConf,
@@ -65,10 +69,4 @@ class OrcWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String])
}
}
}
-
- override def supportsDataType(dataType: DataType): Boolean = {
- OrcDataSourceV2.supportsDataType(dataType)
- }
-
- override def formatName: String = "ORC"
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/FileTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/FileTableSuite.scala
new file mode 100644
index 0000000..3d4f564
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/FileTableSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.v2
+
+import scala.collection.JavaConverters._
+
+import org.apache.hadoop.fs.FileStatus
+
+import org.apache.spark.sql.{QueryTest, SparkSession}
+import org.apache.spark.sql.sources.v2.reader.ScanBuilder
+import org.apache.spark.sql.sources.v2.writer.WriteBuilder
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class DummyFileTable(
+ sparkSession: SparkSession,
+ options: CaseInsensitiveStringMap,
+ paths: Seq[String],
+ expectedDataSchema: StructType,
+ userSpecifiedSchema: Option[StructType])
+ extends FileTable(sparkSession, options, paths, userSpecifiedSchema) {
+ override def inferSchema(files: Seq[FileStatus]): Option[StructType] = Some(expectedDataSchema)
+ // inferSchema above ignores `files` and always reports the test-supplied data schema.
+ override def name(): String = "Dummy"
+ // Prefix of FileTable's "data source does not support ... data type" error message.
+ override def formatName: String = "Dummy"
+ // Scan/write builders are never exercised by FileTableSuite, so they are left null.
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = null
+
+ override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = null
+ // Only StringType is accepted: any non-string data column fails FileTable's type check.
+ override def supportsDataType(dataType: DataType): Boolean = dataType == StringType
+}
+
+class FileTableSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
+ // Covers FileTable data schema resolution and data-type validation (SPARK-27269).
+ test("Data type validation should check data schema only") {
+ withTempPath { dir =>
+ val df = spark.createDataFrame(Seq(("a", 1), ("b", 2))).toDF("v", "p")
+ val pathName = dir.getCanonicalPath
+ df.write.partitionBy("p").text(pathName)
+ val options = new CaseInsensitiveStringMap(Map("path" -> pathName).asJava)
+ val expectedDataSchema = StructType(Seq(StructField("v", StringType, true)))
+ // DummyFileTable doesn't support Integer data type, but the partition column "p" is
+ // handled by Spark, so it may be Integer: resolving the full `schema` must not throw.
+ val table = new DummyFileTable(spark, options, Seq(pathName), expectedDataSchema, None)
+ assert(table.dataSchema == expectedDataSchema)
+ val expectedPartitionSchema = StructType(Seq(StructField("p", IntegerType, true)))
+ assert(table.fileIndex.partitionSchema == expectedPartitionSchema)
+ assert(table.schema.fieldNames.toSet == Set("v", "p"))
+ }
+ }
+
+ test("Returns correct data schema when user specified schema contains partition schema") {
+ withTempPath { dir =>
+ val df = spark.createDataFrame(Seq(("a", 1), ("b", 2))).toDF("v", "p")
+ val pathName = dir.getCanonicalPath
+ df.write.partitionBy("p").text(pathName)
+ val options = new CaseInsensitiveStringMap(Map("path" -> pathName).asJava)
+ val userSpecifiedSchema = Some(StructType(Seq(
+ StructField("v", StringType, true),
+ StructField("p", IntegerType, true))))
+ val expectedDataSchema = StructType(Seq(StructField("v", StringType, true)))
+ val table =
+ new DummyFileTable(spark, options, Seq(pathName), expectedDataSchema, userSpecifiedSchema)
+ assert(table.dataSchema == expectedDataSchema)
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org