You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/02/07 13:02:30 UTC
spark git commit: [SPARK-16101][SQL] Refactoring CSV schema inference path to be consistent with JSON
Repository: spark
Updated Branches:
refs/heads/master 8fd178d21 -> 3d314d08c
[SPARK-16101][SQL] Refactoring CSV schema inference path to be consistent with JSON
## What changes were proposed in this pull request?
This PR refactors the CSV schema inference path to be consistent with the JSON data source, and moves filtering code with similar or identical logic into `CSVUtils`.
It also gives the methods in these classes arguments consistent with their JSON counterparts (this PR renames `.../json/InferSchema.scala` → `.../json/JsonInferSchema.scala`). For comparison, here are `CSVInferSchema` and `JsonInferSchema`:
``` scala
private[csv] object CSVInferSchema {
...
def infer(
csv: Dataset[String],
caseSensitive: Boolean,
options: CSVOptions): StructType = {
...
```
``` scala
private[sql] object JsonInferSchema {
...
def infer(
json: RDD[String],
columnNameOfCorruptRecord: String,
configOptions: JSONOptions): StructType = {
...
```
These changes allow CSV schema inference directly from a `Dataset[String]`. As a result, functionality that is built on `JacksonParser`/`JsonInferSchema` for JSON can easily be implemented with `UnivocityParser`/`CSVInferSchema` for CSV.
This completes the refactoring of the CSV data source, which is now largely consistent with the JSON data source.
## How was this patch tested?
Existing tests should cover this. The build was also checked against Scala 2.10 with:
```
./dev/change-scala-version.sh 2.10
./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package
```
Author: hyukjinkwon <gu...@gmail.com>
Closes #16680 from HyukjinKwon/SPARK-16101-schema-inference.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3d314d08
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3d314d08
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3d314d08
Branch: refs/heads/master
Commit: 3d314d08c9420e74b4bb687603cdd11394eccab5
Parents: 8fd178d
Author: hyukjinkwon <gu...@gmail.com>
Authored: Tue Feb 7 21:02:20 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Feb 7 21:02:20 2017 +0800
----------------------------------------------------------------------
.../org/apache/spark/sql/DataFrameReader.scala | 4 +-
.../datasources/csv/CSVFileFormat.scala | 112 +------
.../datasources/csv/CSVInferSchema.scala | 115 ++++---
.../execution/datasources/csv/CSVOptions.scala | 2 +-
.../execution/datasources/csv/CSVRelation.scala | 69 ----
.../execution/datasources/csv/CSVUtils.scala | 134 ++++++++
.../datasources/json/InferSchema.scala | 329 -------------------
.../datasources/json/JsonFileFormat.scala | 2 +-
.../datasources/json/JsonInferSchema.scala | 329 +++++++++++++++++++
.../datasources/csv/CSVUtilsSuite.scala | 47 +++
.../datasources/csv/UnivocityParserSuite.scala | 24 --
.../execution/datasources/json/JsonSuite.scala | 6 +-
12 files changed, 599 insertions(+), 574 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index a787d5a..1830839 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.datasources.jdbc._
-import org.apache.spark.sql.execution.datasources.json.InferSchema
+import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
import org.apache.spark.sql.types.StructType
/**
@@ -334,7 +334,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
parsedOptions.columnNameOfCorruptRecord
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
val schema = userSpecifiedSchema.getOrElse {
- InferSchema.infer(
+ JsonInferSchema.infer(
jsonRDD,
columnNameOfCorruptRecord,
parsedOptions)
http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 3897016..1d2bf07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.csv
import java.nio.charset.{Charset, StandardCharsets}
-import com.univocity.parsers.csv.CsvParser
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.io.{LongWritable, Text}
@@ -28,13 +27,11 @@ import org.apache.hadoop.mapreduce._
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
-import org.apache.spark.sql.functions.{length, trim}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableConfiguration
@@ -60,64 +57,9 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
val csvOptions = new CSVOptions(options)
val paths = files.map(_.getPath.toString)
- val lines: Dataset[String] = readText(sparkSession, csvOptions, paths)
- val firstLine: String = findFirstLine(csvOptions, lines)
- val firstRow = new CsvParser(csvOptions.asParserSettings).parseLine(firstLine)
+ val lines: Dataset[String] = createBaseDataset(sparkSession, csvOptions, paths)
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
- val header = makeSafeHeader(firstRow, csvOptions, caseSensitive)
-
- val parsedRdd: RDD[Array[String]] = CSVRelation.univocityTokenizer(
- lines,
- firstLine = if (csvOptions.headerFlag) firstLine else null,
- params = csvOptions)
- val schema = if (csvOptions.inferSchemaFlag) {
- CSVInferSchema.infer(parsedRdd, header, csvOptions)
- } else {
- // By default fields are assumed to be StringType
- val schemaFields = header.map { fieldName =>
- StructField(fieldName, StringType, nullable = true)
- }
- StructType(schemaFields)
- }
- Some(schema)
- }
-
- /**
- * Generates a header from the given row which is null-safe and duplicate-safe.
- */
- private def makeSafeHeader(
- row: Array[String],
- options: CSVOptions,
- caseSensitive: Boolean): Array[String] = {
- if (options.headerFlag) {
- val duplicates = {
- val headerNames = row.filter(_ != null)
- .map(name => if (caseSensitive) name else name.toLowerCase)
- headerNames.diff(headerNames.distinct).distinct
- }
-
- row.zipWithIndex.map { case (value, index) =>
- if (value == null || value.isEmpty || value == options.nullValue) {
- // When there are empty strings or the values set in `nullValue`, put the
- // index as the suffix.
- s"_c$index"
- } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
- // When there are case-insensitive duplicates, put the index as the suffix.
- s"$value$index"
- } else if (duplicates.contains(value)) {
- // When there are duplicates, put the index as the suffix.
- s"$value$index"
- } else {
- value
- }
- }
- } else {
- row.zipWithIndex.map { case (_, index) =>
- // Uses default column names, "_c#" where # is its position of fields
- // when header option is disabled.
- s"_c$index"
- }
- }
+ Some(CSVInferSchema.infer(lines, caseSensitive, csvOptions))
}
override def prepareWrite(
@@ -125,7 +67,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
- verifySchema(dataSchema)
+ CSVUtils.verifySchema(dataSchema)
val conf = job.getConfiguration
val csvOptions = new CSVOptions(options)
csvOptions.compressionCodec.foreach { codec =>
@@ -155,13 +97,12 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
val csvOptions = new CSVOptions(options)
- val commentPrefix = csvOptions.comment.toString
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
(file: PartitionedFile) => {
- val lineIterator = {
+ val lines = {
val conf = broadcastedHadoopConf.value.value
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
@@ -170,32 +111,21 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
}
}
- // Consumes the header in the iterator.
- CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)
-
- val filteredIter = lineIterator.filter { line =>
- line.trim.nonEmpty && !line.startsWith(commentPrefix)
+ val linesWithoutHeader = if (csvOptions.headerFlag && file.start == 0) {
+ // Note that if there are only comments in the first block, the header would probably
+ // be not dropped.
+ CSVUtils.dropHeaderLine(lines, csvOptions)
+ } else {
+ lines
}
+ val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, csvOptions)
val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions)
- filteredIter.flatMap(parser.parse)
- }
- }
-
- /**
- * Returns the first line of the first non-empty file in path
- */
- private def findFirstLine(options: CSVOptions, lines: Dataset[String]): String = {
- import lines.sqlContext.implicits._
- val nonEmptyLines = lines.filter(length(trim($"value")) > 0)
- if (options.isCommentSet) {
- nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)).first()
- } else {
- nonEmptyLines.first()
+ filteredLines.flatMap(parser.parse)
}
}
- private def readText(
+ private def createBaseDataset(
sparkSession: SparkSession,
options: CSVOptions,
inputPaths: Seq[String]): Dataset[String] = {
@@ -215,22 +145,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
sparkSession.createDataset(rdd)(Encoders.STRING)
}
}
-
- private def verifySchema(schema: StructType): Unit = {
- def verifyType(dataType: DataType): Unit = dataType match {
- case ByteType | ShortType | IntegerType | LongType | FloatType |
- DoubleType | BooleanType | _: DecimalType | TimestampType |
- DateType | StringType =>
-
- case udt: UserDefinedType[_] => verifyType(udt.sqlType)
-
- case _ =>
- throw new UnsupportedOperationException(
- s"CSV data source does not support ${dataType.simpleString} data type.")
- }
-
- schema.foreach(field => verifyType(field.dataType))
- }
}
private[csv] class CsvOutputWriter(
http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index 065bf53..485b186 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -18,17 +18,15 @@
package org.apache.spark.sql.execution.datasources.csv
import java.math.BigDecimal
-import java.text.NumberFormat
-import java.util.Locale
import scala.util.control.Exception._
-import scala.util.Try
-import org.apache.spark.rdd.RDD
+import com.univocity.parsers.csv.CsvParser
+
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
private[csv] object CSVInferSchema {
@@ -39,22 +37,76 @@ private[csv] object CSVInferSchema {
* 3. Replace any null types with string type
*/
def infer(
- tokenRdd: RDD[Array[String]],
- header: Array[String],
+ csv: Dataset[String],
+ caseSensitive: Boolean,
options: CSVOptions): StructType = {
- val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
- val rootTypes: Array[DataType] =
- tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)
-
- val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
- val dType = rootType match {
- case _: NullType => StringType
- case other => other
+ val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, options).first()
+ val firstRow = new CsvParser(options.asParserSettings).parseLine(firstLine)
+ val header = makeSafeHeader(firstRow, caseSensitive, options)
+
+ val fields = if (options.inferSchemaFlag) {
+ val tokenRdd = csv.rdd.mapPartitions { iter =>
+ val filteredLines = CSVUtils.filterCommentAndEmpty(iter, options)
+ val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, options)
+ val parser = new CsvParser(options.asParserSettings)
+ linesWithoutHeader.map(parser.parseLine)
+ }
+
+ val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
+ val rootTypes: Array[DataType] =
+ tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)
+
+ header.zip(rootTypes).map { case (thisHeader, rootType) =>
+ val dType = rootType match {
+ case _: NullType => StringType
+ case other => other
+ }
+ StructField(thisHeader, dType, nullable = true)
}
- StructField(thisHeader, dType, nullable = true)
+ } else {
+ // By default fields are assumed to be StringType
+ header.map(fieldName => StructField(fieldName, StringType, nullable = true))
}
- StructType(structFields)
+ StructType(fields)
+ }
+
+ /**
+ * Generates a header from the given row which is null-safe and duplicate-safe.
+ */
+ private def makeSafeHeader(
+ row: Array[String],
+ caseSensitive: Boolean,
+ options: CSVOptions): Array[String] = {
+ if (options.headerFlag) {
+ val duplicates = {
+ val headerNames = row.filter(_ != null)
+ .map(name => if (caseSensitive) name else name.toLowerCase)
+ headerNames.diff(headerNames.distinct).distinct
+ }
+
+ row.zipWithIndex.map { case (value, index) =>
+ if (value == null || value.isEmpty || value == options.nullValue) {
+ // When there are empty strings or the values set in `nullValue`, put the
+ // index as the suffix.
+ s"_c$index"
+ } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
+ // When there are case-insensitive duplicates, put the index as the suffix.
+ s"$value$index"
+ } else if (duplicates.contains(value)) {
+ // When there are duplicates, put the index as the suffix.
+ s"$value$index"
+ } else {
+ value
+ }
+ }
+ } else {
+ row.zipWithIndex.map { case (_, index) =>
+ // Uses default column names, "_c#" where # is its position of fields
+ // when header option is disabled.
+ s"_c$index"
+ }
+ }
}
private def inferRowType(options: CSVOptions)
@@ -215,32 +267,3 @@ private[csv] object CSVInferSchema {
case _ => None
}
}
-
-private[csv] object CSVTypeCast {
- /**
- * Helper method that converts string representation of a character to actual character.
- * It handles some Java escaped strings and throws exception if given string is longer than one
- * character.
- */
- @throws[IllegalArgumentException]
- def toChar(str: String): Char = {
- if (str.charAt(0) == '\\') {
- str.charAt(1)
- match {
- case 't' => '\t'
- case 'r' => '\r'
- case 'b' => '\b'
- case 'f' => '\f'
- case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options
- case '\'' => '\''
- case 'u' if str == """\u0000""" => '\u0000'
- case _ =>
- throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str")
- }
- } else if (str.length == 1) {
- str.charAt(0)
- } else {
- throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str")
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 140ce23..af456c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -69,7 +69,7 @@ private[csv] class CSVOptions(@transient private val parameters: CaseInsensitive
}
}
- val delimiter = CSVTypeCast.toChar(
+ val delimiter = CSVUtils.toChar(
parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
val charset = parameters.getOrElse("encoding",
http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
deleted file mode 100644
index 19058c2..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.csv
-
-import com.univocity.parsers.csv.CsvParser
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql._
-import org.apache.spark.sql.execution.datasources.PartitionedFile
-
-object CSVRelation extends Logging {
-
- def univocityTokenizer(
- file: Dataset[String],
- firstLine: String,
- params: CSVOptions): RDD[Array[String]] = {
- // If header is set, make sure firstLine is materialized before sending to executors.
- val commentPrefix = params.comment.toString
- file.rdd.mapPartitions { iter =>
- val parser = new CsvParser(params.asParserSettings)
- val filteredIter = iter.filter { line =>
- line.trim.nonEmpty && !line.startsWith(commentPrefix)
- }
- if (params.headerFlag) {
- filteredIter.filterNot(_ == firstLine).map { item =>
- parser.parseLine(item)
- }
- } else {
- filteredIter.map { item =>
- parser.parseLine(item)
- }
- }
- }
- }
-
- // Skips the header line of each file if the `header` option is set to true.
- def dropHeaderLine(
- file: PartitionedFile, lines: Iterator[String], csvOptions: CSVOptions): Unit = {
- // TODO What if the first partitioned file consists of only comments and empty lines?
- if (csvOptions.headerFlag && file.start == 0) {
- val nonEmptyLines = if (csvOptions.isCommentSet) {
- val commentPrefix = csvOptions.comment.toString
- lines.dropWhile { line =>
- line.trim.isEmpty || line.trim.startsWith(commentPrefix)
- }
- } else {
- lines.dropWhile(_.trim.isEmpty)
- }
-
- if (nonEmptyLines.hasNext) nonEmptyLines.drop(1)
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
new file mode 100644
index 0000000..72b053d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+object CSVUtils {
+ /**
+ * Filter ignorable rows for CSV dataset (lines empty and starting with `comment`).
+ * This is currently being used in CSV schema inference.
+ */
+ def filterCommentAndEmpty(lines: Dataset[String], options: CSVOptions): Dataset[String] = {
+ // Note that this was separately made by SPARK-18362. Logically, this should be the same
+ // with the one below, `filterCommentAndEmpty` but execution path is different. One of them
+ // might have to be removed in the near future if possible.
+ import lines.sqlContext.implicits._
+ val nonEmptyLines = lines.filter(length(trim($"value")) > 0)
+ if (options.isCommentSet) {
+ nonEmptyLines.filter(!$"value".startsWith(options.comment.toString))
+ } else {
+ nonEmptyLines
+ }
+ }
+
+ /**
+ * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`).
+ * This is currently being used in CSV reading path and CSV schema inference.
+ */
+ def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
+ iter.filter { line =>
+ line.trim.nonEmpty && !line.startsWith(options.comment.toString)
+ }
+ }
+
+ /**
+ * Skip the given first line so that only data can remain in a dataset.
+ * This is similar with `dropHeaderLine` below and currently being used in CSV schema inference.
+ */
+ def filterHeaderLine(
+ iter: Iterator[String],
+ firstLine: String,
+ options: CSVOptions): Iterator[String] = {
+ // Note that unlike actual CSV reading path, it simply filters the given first line. Therefore,
+ // this skips the line same with the header if exists. One of them might have to be removed
+ // in the near future if possible.
+ if (options.headerFlag) {
+ iter.filterNot(_ == firstLine)
+ } else {
+ iter
+ }
+ }
+
+ /**
+ * Drop header line so that only data can remain.
+ * This is similar with `filterHeaderLine` above and currently being used in CSV reading path.
+ */
+ def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
+ val nonEmptyLines = if (options.isCommentSet) {
+ val commentPrefix = options.comment.toString
+ iter.dropWhile { line =>
+ line.trim.isEmpty || line.trim.startsWith(commentPrefix)
+ }
+ } else {
+ iter.dropWhile(_.trim.isEmpty)
+ }
+
+ if (nonEmptyLines.hasNext) nonEmptyLines.drop(1)
+ iter
+ }
+
+ /**
+ * Helper method that converts string representation of a character to actual character.
+ * It handles some Java escaped strings and throws exception if given string is longer than one
+ * character.
+ */
+ @throws[IllegalArgumentException]
+ def toChar(str: String): Char = {
+ if (str.charAt(0) == '\\') {
+ str.charAt(1)
+ match {
+ case 't' => '\t'
+ case 'r' => '\r'
+ case 'b' => '\b'
+ case 'f' => '\f'
+ case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options
+ case '\'' => '\''
+ case 'u' if str == """\u0000""" => '\u0000'
+ case _ =>
+ throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str")
+ }
+ } else if (str.length == 1) {
+ str.charAt(0)
+ } else {
+ throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str")
+ }
+ }
+
+ /**
+ * Verify if the schema is supported in CSV datasource.
+ */
+ def verifySchema(schema: StructType): Unit = {
+ def verifyType(dataType: DataType): Unit = dataType match {
+ case ByteType | ShortType | IntegerType | LongType | FloatType |
+ DoubleType | BooleanType | _: DecimalType | TimestampType |
+ DateType | StringType =>
+
+ case udt: UserDefinedType[_] => verifyType(udt.sqlType)
+
+ case _ =>
+ throw new UnsupportedOperationException(
+ s"CSV data source does not support ${dataType.simpleString} data type.")
+ }
+
+ schema.foreach(field => verifyType(field.dataType))
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
deleted file mode 100644
index 330d04d..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
+++ /dev/null
@@ -1,329 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.json
-
-import java.util.Comparator
-
-import com.fasterxml.jackson.core._
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.TypeCoercion
-import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
-import org.apache.spark.sql.catalyst.json.JSONOptions
-import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
-
-private[sql] object InferSchema {
-
- /**
- * Infer the type of a collection of json records in three stages:
- * 1. Infer the type of each record
- * 2. Merge types by choosing the lowest type necessary to cover equal keys
- * 3. Replace any remaining null fields with string, the top type
- */
- def infer(
- json: RDD[String],
- columnNameOfCorruptRecord: String,
- configOptions: JSONOptions): StructType = {
- require(configOptions.samplingRatio > 0,
- s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
- val shouldHandleCorruptRecord = configOptions.permissive
- val schemaData = if (configOptions.samplingRatio > 0.99) {
- json
- } else {
- json.sample(withReplacement = false, configOptions.samplingRatio, 1)
- }
-
- // perform schema inference on each row and merge afterwards
- val rootType = schemaData.mapPartitions { iter =>
- val factory = new JsonFactory()
- configOptions.setJacksonOptions(factory)
- iter.flatMap { row =>
- try {
- Utils.tryWithResource(factory.createParser(row)) { parser =>
- parser.nextToken()
- Some(inferField(parser, configOptions))
- }
- } catch {
- case _: JsonParseException if shouldHandleCorruptRecord =>
- Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType))))
- case _: JsonParseException =>
- None
- }
- }
- }.fold(StructType(Seq()))(
- compatibleRootType(columnNameOfCorruptRecord, shouldHandleCorruptRecord))
-
- canonicalizeType(rootType) match {
- case Some(st: StructType) => st
- case _ =>
- // canonicalizeType erases all empty structs, including the only one we want to keep
- StructType(Seq())
- }
- }
-
- private[this] val structFieldComparator = new Comparator[StructField] {
- override def compare(o1: StructField, o2: StructField): Int = {
- o1.name.compare(o2.name)
- }
- }
-
- private def isSorted(arr: Array[StructField]): Boolean = {
- var i: Int = 0
- while (i < arr.length - 1) {
- if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
- return false
- }
- i += 1
- }
- true
- }
-
- /**
- * Infer the type of a json document from the parser's token stream
- */
- private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
- import com.fasterxml.jackson.core.JsonToken._
- parser.getCurrentToken match {
- case null | VALUE_NULL => NullType
-
- case FIELD_NAME =>
- parser.nextToken()
- inferField(parser, configOptions)
-
- case VALUE_STRING if parser.getTextLength < 1 =>
- // Zero length strings and nulls have special handling to deal
- // with JSON generators that do not distinguish between the two.
- // To accurately infer types for empty strings that are really
- // meant to represent nulls we assume that the two are isomorphic
- // but will defer treating null fields as strings until all the
- // record fields' types have been combined.
- NullType
-
- case VALUE_STRING => StringType
- case START_OBJECT =>
- val builder = Array.newBuilder[StructField]
- while (nextUntil(parser, END_OBJECT)) {
- builder += StructField(
- parser.getCurrentName,
- inferField(parser, configOptions),
- nullable = true)
- }
- val fields: Array[StructField] = builder.result()
- // Note: other code relies on this sorting for correctness, so don't remove it!
- java.util.Arrays.sort(fields, structFieldComparator)
- StructType(fields)
-
- case START_ARRAY =>
- // If this JSON array is empty, we use NullType as a placeholder.
- // If this array is not empty in other JSON objects, we can resolve
- // the type as we pass through all JSON objects.
- var elementType: DataType = NullType
- while (nextUntil(parser, END_ARRAY)) {
- elementType = compatibleType(
- elementType, inferField(parser, configOptions))
- }
-
- ArrayType(elementType)
-
- case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType
-
- case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType
-
- case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
- import JsonParser.NumberType._
- parser.getNumberType match {
- // For Integer values, use LongType by default.
- case INT | LONG => LongType
- // Since we do not have a data type backed by BigInteger,
- // when we see a Java BigInteger, we use DecimalType.
- case BIG_INTEGER | BIG_DECIMAL =>
- val v = parser.getDecimalValue
- if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
- DecimalType(Math.max(v.precision(), v.scale()), v.scale())
- } else {
- DoubleType
- }
- case FLOAT | DOUBLE if configOptions.prefersDecimal =>
- val v = parser.getDecimalValue
- if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
- DecimalType(Math.max(v.precision(), v.scale()), v.scale())
- } else {
- DoubleType
- }
- case FLOAT | DOUBLE =>
- DoubleType
- }
-
- case VALUE_TRUE | VALUE_FALSE => BooleanType
- }
- }
-
- /**
- * Convert NullType to StringType and remove StructTypes with no fields
- */
- private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match {
- case at @ ArrayType(elementType, _) =>
- for {
- canonicalType <- canonicalizeType(elementType)
- } yield {
- at.copy(canonicalType)
- }
-
- case StructType(fields) =>
- val canonicalFields: Array[StructField] = for {
- field <- fields
- if field.name.length > 0
- canonicalType <- canonicalizeType(field.dataType)
- } yield {
- field.copy(dataType = canonicalType)
- }
-
- if (canonicalFields.length > 0) {
- Some(StructType(canonicalFields))
- } else {
- // per SPARK-8093: empty structs should be deleted
- None
- }
-
- case NullType => Some(StringType)
- case other => Some(other)
- }
-
- private def withCorruptField(
- struct: StructType,
- columnNameOfCorruptRecords: String): StructType = {
- if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
- // If this given struct does not have a column used for corrupt records,
- // add this field.
- val newFields: Array[StructField] =
- StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
- // Note: other code relies on this sorting for correctness, so don't remove it!
- java.util.Arrays.sort(newFields, structFieldComparator)
- StructType(newFields)
- } else {
- // Otherwise, just return this struct.
- struct
- }
- }
-
- /**
- * Remove top-level ArrayType wrappers and merge the remaining schemas
- */
- private def compatibleRootType(
- columnNameOfCorruptRecords: String,
- shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = {
- // Since we support array of json objects at the top level,
- // we need to check the element type and find the root level data type.
- case (ArrayType(ty1, _), ty2) =>
- compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2)
- case (ty1, ArrayType(ty2, _)) =>
- compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2)
- // If we see any other data type at the root level, we get records that cannot be
- // parsed. So, we use the struct as the data type and add the corrupt field to the schema.
- case (struct: StructType, NullType) => struct
- case (NullType, struct: StructType) => struct
- case (struct: StructType, o) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord =>
- withCorruptField(struct, columnNameOfCorruptRecords)
- case (o, struct: StructType) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord =>
- withCorruptField(struct, columnNameOfCorruptRecords)
- // If we get anything else, we call compatibleType.
- // Usually, when we reach here, ty1 and ty2 are two StructTypes.
- case (ty1, ty2) => compatibleType(ty1, ty2)
- }
-
- private[this] val emptyStructFieldArray = Array.empty[StructField]
-
- /**
- * Returns the most general data type for two given data types.
- */
- def compatibleType(t1: DataType, t2: DataType): DataType = {
- TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
- // t1 or t2 is a StructType, ArrayType, or an unexpected type.
- (t1, t2) match {
- // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
- // in most case, also have better precision.
- case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
- DoubleType
-
- case (t1: DecimalType, t2: DecimalType) =>
- val scale = math.max(t1.scale, t2.scale)
- val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
- if (range + scale > 38) {
- // DecimalType can't support precision > 38
- DoubleType
- } else {
- DecimalType(range + scale, scale)
- }
-
- case (StructType(fields1), StructType(fields2)) =>
- // Both fields1 and fields2 should be sorted by name, since inferField performs sorting.
- // Therefore, we can take advantage of the fact that we're merging sorted lists and skip
- // building a hash map or performing additional sorting.
- assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}")
- assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}")
-
- val newFields = new java.util.ArrayList[StructField]()
-
- var f1Idx = 0
- var f2Idx = 0
-
- while (f1Idx < fields1.length && f2Idx < fields2.length) {
- val f1Name = fields1(f1Idx).name
- val f2Name = fields2(f2Idx).name
- val comp = f1Name.compareTo(f2Name)
- if (comp == 0) {
- val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
- newFields.add(StructField(f1Name, dataType, nullable = true))
- f1Idx += 1
- f2Idx += 1
- } else if (comp < 0) { // f1Name < f2Name
- newFields.add(fields1(f1Idx))
- f1Idx += 1
- } else { // f1Name > f2Name
- newFields.add(fields2(f2Idx))
- f2Idx += 1
- }
- }
- while (f1Idx < fields1.length) {
- newFields.add(fields1(f1Idx))
- f1Idx += 1
- }
- while (f2Idx < fields2.length) {
- newFields.add(fields2(f2Idx))
- f2Idx += 1
- }
- StructType(newFields.toArray(emptyStructFieldArray))
-
- case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
- ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
-
- // The case that given `DecimalType` is capable of given `IntegralType` is handled in
- // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when
- // the given `DecimalType` is not capable of the given `IntegralType`.
- case (t1: IntegralType, t2: DecimalType) =>
- compatibleType(DecimalType.forType(t1), t2)
- case (t1: DecimalType, t2: IntegralType) =>
- compatibleType(t1, DecimalType.forType(t2))
-
- // strings and every string is a Json object.
- case (_, _) => StringType
- }
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index be1f94d..98ab9d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -51,7 +51,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
val columnNameOfCorruptRecord =
parsedOptions.columnNameOfCorruptRecord
.getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
- val jsonSchema = InferSchema.infer(
+ val jsonSchema = JsonInferSchema.infer(
createBaseRdd(sparkSession, files),
columnNameOfCorruptRecord,
parsedOptions)
http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
new file mode 100644
index 0000000..f51c18d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -0,0 +1,329 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.json
+
+import java.util.Comparator
+
+import com.fasterxml.jackson.core._
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion
+import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
+import org.apache.spark.sql.catalyst.json.JSONOptions
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+private[sql] object JsonInferSchema {
+
+  /**
+   * Infer the type of a collection of json records in three stages:
+   *   1. Infer the type of each record
+   *   2. Merge types by choosing the lowest type necessary to cover equal keys
+   *   3. Replace any remaining null fields with string, the top type
+   *
+   * @param json the JSON records; each element is parsed as one document
+   * @param columnNameOfCorruptRecord name of the string column added for records that fail to
+   *                                  parse, when `configOptions.permissive` is set
+   * @param configOptions options consulted here: `samplingRatio`, `permissive`,
+   *                      `primitivesAsString`, `prefersDecimal` and the Jackson settings
+   * @return the merged schema, or an empty `StructType` if nothing could be inferred
+   * @throws IllegalArgumentException if `configOptions.samplingRatio` is not positive
+   */
+  def infer(
+      json: RDD[String],
+      columnNameOfCorruptRecord: String,
+      configOptions: JSONOptions): StructType = {
+    require(configOptions.samplingRatio > 0,
+      s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
+    val shouldHandleCorruptRecord = configOptions.permissive
+    // A ratio above 0.99 uses every record; otherwise sample without replacement using the
+    // fixed seed 1.
+    val schemaData = if (configOptions.samplingRatio > 0.99) {
+      json
+    } else {
+      json.sample(withReplacement = false, configOptions.samplingRatio, 1)
+    }
+
+    // perform schema inference on each row and merge afterwards
+    val rootType = schemaData.mapPartitions { iter =>
+      // One factory per partition, reused for every row of that partition.
+      val factory = new JsonFactory()
+      configOptions.setJacksonOptions(factory)
+      iter.flatMap { row =>
+        try {
+          Utils.tryWithResource(factory.createParser(row)) { parser =>
+            parser.nextToken()
+            Some(inferField(parser, configOptions))
+          }
+        } catch {
+          // A record that fails to parse contributes only the corrupt-record column in
+          // permissive mode, and is left out of inference otherwise.
+          case _: JsonParseException if shouldHandleCorruptRecord =>
+            Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType))))
+          case _: JsonParseException =>
+            None
+        }
+      }
+    }.fold(StructType(Seq()))(
+      compatibleRootType(columnNameOfCorruptRecord, shouldHandleCorruptRecord))
+
+    canonicalizeType(rootType) match {
+      case Some(st: StructType) => st
+      case _ =>
+        // canonicalizeType erases all empty structs, including the only one we want to keep
+        StructType(Seq())
+    }
+  }
+
+  /** Orders struct fields by name; the struct merge in `compatibleType` relies on this. */
+  private[this] val structFieldComparator = new Comparator[StructField] {
+    override def compare(o1: StructField, o2: StructField): Int = {
+      o1.name.compare(o2.name)
+    }
+  }
+
+  /** Returns true if `arr` is in non-decreasing order according to `structFieldComparator`. */
+  private def isSorted(arr: Array[StructField]): Boolean = {
+    (1 until arr.length).forall { i =>
+      structFieldComparator.compare(arr(i - 1), arr(i)) <= 0
+    }
+  }
+
+  /**
+   * Maps an exact numeric value to a `DecimalType` large enough to hold it. Falls back to
+   * `DoubleType` when that would need more than `DecimalType.MAX_PRECISION` digits.
+   */
+  private def decimalOrDouble(v: java.math.BigDecimal): DataType = {
+    // precision can be smaller than scale (0.001 has precision 1 but scale 3); taking the
+    // larger of the two keeps the resulting type's precision >= its scale.
+    val precision = Math.max(v.precision(), v.scale())
+    if (precision <= DecimalType.MAX_PRECISION) {
+      DecimalType(precision, v.scale())
+    } else {
+      DoubleType
+    }
+  }
+
+  /**
+   * Infer the type of a json document from the parser's token stream.
+   *
+   * `parser` is expected to be positioned on the first token of the value to inspect. Object
+   * and array values are walked with `nextUntil` up to their END_OBJECT / END_ARRAY token.
+   */
+  private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
+    import com.fasterxml.jackson.core.JsonToken._
+    parser.getCurrentToken match {
+      case null | VALUE_NULL => NullType
+
+      case FIELD_NAME =>
+        parser.nextToken()
+        inferField(parser, configOptions)
+
+      case VALUE_STRING if parser.getTextLength < 1 =>
+        // Zero length strings and nulls have special handling to deal
+        // with JSON generators that do not distinguish between the two.
+        // To accurately infer types for empty strings that are really
+        // meant to represent nulls we assume that the two are isomorphic
+        // but will defer treating null fields as strings until all the
+        // record fields' types have been combined.
+        NullType
+
+      case VALUE_STRING => StringType
+      case START_OBJECT =>
+        val builder = Array.newBuilder[StructField]
+        while (nextUntil(parser, END_OBJECT)) {
+          builder += StructField(
+            parser.getCurrentName,
+            inferField(parser, configOptions),
+            nullable = true)
+        }
+        val fields: Array[StructField] = builder.result()
+        // Note: other code relies on this sorting for correctness, so don't remove it!
+        java.util.Arrays.sort(fields, structFieldComparator)
+        StructType(fields)
+
+      case START_ARRAY =>
+        // If this JSON array is empty, we use NullType as a placeholder.
+        // If this array is not empty in other JSON objects, we can resolve
+        // the type as we pass through all JSON objects.
+        var elementType: DataType = NullType
+        while (nextUntil(parser, END_ARRAY)) {
+          elementType = compatibleType(
+            elementType, inferField(parser, configOptions))
+        }
+
+        ArrayType(elementType)
+
+      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType
+
+      case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType
+
+      case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
+        import JsonParser.NumberType._
+        parser.getNumberType match {
+          // For Integer values, use LongType by default.
+          case INT | LONG => LongType
+          // Since we do not have a data type backed by BigInteger,
+          // when we see a Java BigInteger, we use DecimalType.
+          case BIG_INTEGER | BIG_DECIMAL =>
+            decimalOrDouble(parser.getDecimalValue)
+          case FLOAT | DOUBLE if configOptions.prefersDecimal =>
+            decimalOrDouble(parser.getDecimalValue)
+          case FLOAT | DOUBLE =>
+            DoubleType
+        }
+
+      case VALUE_TRUE | VALUE_FALSE => BooleanType
+    }
+  }
+
+  /**
+   * Convert NullType to StringType and remove StructTypes with no fields.
+   *
+   * Struct fields with an empty name are dropped. A field or array whose type canonicalizes to
+   * nothing (i.e. contains only empty structs) is dropped as well.
+   */
+  private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match {
+    case at @ ArrayType(elementType, _) =>
+      for {
+        canonicalType <- canonicalizeType(elementType)
+      } yield {
+        at.copy(canonicalType)
+      }
+
+    case StructType(fields) =>
+      val canonicalFields: Array[StructField] = for {
+        field <- fields
+        if field.name.length > 0
+        canonicalType <- canonicalizeType(field.dataType)
+      } yield {
+        field.copy(dataType = canonicalType)
+      }
+
+      if (canonicalFields.length > 0) {
+        Some(StructType(canonicalFields))
+      } else {
+        // per SPARK-8093: empty structs should be deleted
+        None
+      }
+
+    case NullType => Some(StringType)
+    case other => Some(other)
+  }
+
+  /**
+   * Returns `struct` with a nullable string field named `columnNameOfCorruptRecords` added,
+   * keeping the fields sorted by name. Returns `struct` unchanged if it already has that field.
+   */
+  private def withCorruptField(
+      struct: StructType,
+      columnNameOfCorruptRecords: String): StructType = {
+    if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
+      // If this given struct does not have a column used for corrupt records,
+      // add this field.
+      val newFields: Array[StructField] =
+        StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
+      // Note: other code relies on this sorting for correctness, so don't remove it!
+      java.util.Arrays.sort(newFields, structFieldComparator)
+      StructType(newFields)
+    } else {
+      // Otherwise, just return this struct.
+      struct
+    }
+  }
+
+  /**
+   * Remove top-level ArrayType wrappers and merge the remaining schemas.
+   *
+   * The returned function is the combiner used to fold per-record root types together. When
+   * `shouldHandleCorruptRecord` is set, merging a struct with a root type that is neither a
+   * struct nor NullType adds the corrupt-record column to the struct.
+   */
+  private def compatibleRootType(
+      columnNameOfCorruptRecords: String,
+      shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = {
+    // Since we support array of json objects at the top level,
+    // we need to check the element type and find the root level data type.
+    case (ArrayType(ty1, _), ty2) =>
+      compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2)
+    case (ty1, ArrayType(ty2, _)) =>
+      compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2)
+    // If we see any other data type at the root level, we get records that cannot be
+    // parsed. So, we use the struct as the data type and add the corrupt field to the schema.
+    case (struct: StructType, NullType) => struct
+    case (NullType, struct: StructType) => struct
+    case (struct: StructType, o) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord =>
+      withCorruptField(struct, columnNameOfCorruptRecords)
+    case (o, struct: StructType) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord =>
+      withCorruptField(struct, columnNameOfCorruptRecords)
+    // If we get anything else, we call compatibleType.
+    // Usually, when we reach here, ty1 and ty2 are two StructTypes.
+    case (ty1, ty2) => compatibleType(ty1, ty2)
+  }
+
+  /** Typed argument for `java.util.ArrayList.toArray` in the struct merge below. */
+  private[this] val emptyStructFieldArray = Array.empty[StructField]
+
+  /**
+   * Returns the most general data type for two given data types.
+   *
+   * Struct fields are merged by name (both inputs must have name-sorted fields, as produced
+   * by `inferField`). Pairs that cannot be reconciled fall back to StringType.
+   */
+  def compatibleType(t1: DataType, t2: DataType): DataType = {
+    TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
+      // No tightest common type: reconcile decimals, structs and arrays here; anything else
+      // falls back to StringType.
+      (t1, t2) match {
+        // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
+        // in most case, also have better precision.
+        case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
+          DoubleType
+
+        case (d1: DecimalType, d2: DecimalType) =>
+          val scale = math.max(d1.scale, d2.scale)
+          val range = math.max(d1.precision - d1.scale, d2.precision - d2.scale)
+          if (range + scale > DecimalType.MAX_PRECISION) {
+            // The widened decimal would exceed the largest supported precision.
+            DoubleType
+          } else {
+            DecimalType(range + scale, scale)
+          }
+
+        case (StructType(fields1), StructType(fields2)) =>
+          // Both fields1 and fields2 should be sorted by name, since inferField performs sorting.
+          // Therefore, we can take advantage of the fact that we're merging sorted lists and skip
+          // building a hash map or performing additional sorting.
+          assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}")
+          assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}")
+
+          val newFields = new java.util.ArrayList[StructField]()
+
+          var f1Idx = 0
+          var f2Idx = 0
+
+          while (f1Idx < fields1.length && f2Idx < fields2.length) {
+            val f1Name = fields1(f1Idx).name
+            val f2Name = fields2(f2Idx).name
+            val comp = f1Name.compareTo(f2Name)
+            if (comp == 0) {
+              val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
+              newFields.add(StructField(f1Name, dataType, nullable = true))
+              f1Idx += 1
+              f2Idx += 1
+            } else if (comp < 0) { // f1Name < f2Name
+              newFields.add(fields1(f1Idx))
+              f1Idx += 1
+            } else { // f1Name > f2Name
+              newFields.add(fields2(f2Idx))
+              f2Idx += 1
+            }
+          }
+          // At most one of the two loops below has remaining fields to copy.
+          while (f1Idx < fields1.length) {
+            newFields.add(fields1(f1Idx))
+            f1Idx += 1
+          }
+          while (f2Idx < fields2.length) {
+            newFields.add(fields2(f2Idx))
+            f2Idx += 1
+          }
+          StructType(newFields.toArray(emptyStructFieldArray))
+
+        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+          ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+
+        // The case that given `DecimalType` is capable of given `IntegralType` is handled in
+        // `TypeCoercion.findTightestCommonType`. Both cases below will be executed only when
+        // the given `DecimalType` is not capable of the given `IntegralType`.
+        case (i1: IntegralType, d2: DecimalType) =>
+          compatibleType(DecimalType.forType(i1), d2)
+        case (d1: DecimalType, i2: IntegralType) =>
+          compatibleType(d1, DecimalType.forType(i2))
+
+        // Anything else (e.g. a struct vs. a primitive, or primitives not reconciled above)
+        // falls back to StringType, the top type.
+        case (_, _) => StringType
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala
new file mode 100644
index 0000000..221e44c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import org.apache.spark.SparkFunSuite
+
+/** Unit tests for `CSVUtils.toChar`, which decodes delimiter option strings. */
+class CSVUtilsSuite extends SparkFunSuite {
+
+  test("Can parse escaped characters") {
+    // Each escaped spelling paired with the single character it must decode to.
+    val expectations = Seq(
+      """\t""" -> '\t',
+      """\r""" -> '\r',
+      """\b""" -> '\b',
+      """\f""" -> '\f',
+      """\"""" -> '\"',
+      """\'""" -> '\'',
+      """\u0000""" -> '\u0000')
+    expectations.foreach { case (escaped, expected) =>
+      assert(CSVUtils.toChar(escaped) === expected)
+    }
+  }
+
+  test("Does not accept delimiter larger than one character") {
+    val thrown = intercept[IllegalArgumentException](CSVUtils.toChar("ab"))
+    assert(thrown.getMessage.contains("cannot be more than one character"))
+  }
+
+  test("Throws exception for unsupported escaped characters") {
+    val thrown = intercept[IllegalArgumentException](CSVUtils.toChar("""\1"""))
+    assert(thrown.getMessage.contains("Unsupported special character for delimiter"))
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
index 2ca6308..62dae08 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
@@ -43,30 +43,6 @@ class UnivocityParserSuite extends SparkFunSuite {
}
}
- test("Can parse escaped characters") {
- assert(CSVTypeCast.toChar("""\t""") === '\t')
- assert(CSVTypeCast.toChar("""\r""") === '\r')
- assert(CSVTypeCast.toChar("""\b""") === '\b')
- assert(CSVTypeCast.toChar("""\f""") === '\f')
- assert(CSVTypeCast.toChar("""\"""") === '\"')
- assert(CSVTypeCast.toChar("""\'""") === '\'')
- assert(CSVTypeCast.toChar("""\u0000""") === '\u0000')
- }
-
- test("Does not accept delimiter larger than one character") {
- val exception = intercept[IllegalArgumentException]{
- CSVTypeCast.toChar("ab")
- }
- assert(exception.getMessage.contains("cannot be more than one character"))
- }
-
- test("Throws exception for unsupported escaped characters") {
- val exception = intercept[IllegalArgumentException]{
- CSVTypeCast.toChar("""\1""")
- }
- assert(exception.getMessage.contains("Unsupported special character for delimiter"))
- }
-
test("Nullable types are handled") {
val types = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
BooleanType, DecimalType.DoubleDecimal, TimestampType, DateType, StringType)
http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 161a409..156fd96 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType
+import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -1366,7 +1366,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
test("SPARK-6245 JsonRDD.inferSchema on empty RDD") {
// This is really a test that it doesn't throw an exception
- val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map.empty[String, String]))
+ val emptySchema = JsonInferSchema.infer(empty, "", new JSONOptions(Map.empty[String, String]))
assert(StructType(Seq()) === emptySchema)
}
@@ -1390,7 +1390,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("SPARK-8093 Erase empty structs") {
- val emptySchema = InferSchema.infer(
+ val emptySchema = JsonInferSchema.infer(
emptyRecords, "", new JSONOptions(Map.empty[String, String]))
assert(StructType(Seq()) === emptySchema)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org