Posted to commits@spark.apache.org by we...@apache.org on 2017/02/07 13:02:30 UTC

spark git commit: [SPARK-16101][SQL] Refactoring CSV schema inference path to be consistent with JSON

Repository: spark
Updated Branches:
  refs/heads/master 8fd178d21 -> 3d314d08c


[SPARK-16101][SQL] Refactoring CSV schema inference path to be consistent with JSON

## What changes were proposed in this pull request?

This PR refactors the CSV schema inference path to be consistent with the JSON data source and moves filtering code with similar or identical logic into `CSVUtils`.

It also makes the methods in these classes take arguments consistent with their JSON counterparts. (This PR renames `.../json/InferSchema.scala` → `.../json/JsonInferSchema.scala`.)

`CSVInferSchema` and `JsonInferSchema` now look as follows:

``` scala
private[csv] object CSVInferSchema {
  ...

  def infer(
      csv: Dataset[String],
      caseSensitive: Boolean,
      options: CSVOptions): StructType = {
  ...
```

``` scala
private[sql] object JsonInferSchema {
  ...

  def infer(
      json: RDD[String],
      columnNameOfCorruptRecord: String,
      configOptions: JSONOptions): StructType = {
  ...
```

These changes allow schema inference directly from a `Dataset[String]`, meaning functionality that uses `JacksonParser`/`JsonInferSchema` for JSON can be implemented in the same way with `UnivocityParser`/`CSVInferSchema` for CSV.

This completes the refactoring of the CSV data source; the CSV and JSON paths are now largely consistent.
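
For illustration, here is a minimal sketch of how the refactored entry point can be driven from a `Dataset[String]`. This is not code from this PR; it assumes a caller inside the `csv` package (since `CSVInferSchema` is `private[csv]`), and the sample data and option values are made up.

``` scala
// Hypothetical sketch: infer a CSV schema straight from a Dataset[String],
// mirroring the JSON inference path.
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.types.StructType

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Lines as they would come from CSVFileFormat.createBaseDataset.
val lines: Dataset[String] = Seq("id,name", "1,foo", "2,bar").toDS()

// The same user-facing CSV options that DataFrameReader.option(...) accepts.
val options = new CSVOptions(Map("header" -> "true", "inferSchema" -> "true"))
val caseSensitive = spark.sessionState.conf.caseSensitiveAnalysis

// Same call shape as CSVFileFormat.inferSchema after this PR.
val schema: StructType = CSVInferSchema.infer(lines, caseSensitive, options)
// Roughly: StructType(StructField(id,IntegerType,true), StructField(name,StringType,true))
```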

## How was this patch tested?

Existing tests should cover this, and the build was additionally verified with:

```
./dev/change-scala-version.sh 2.10
./build/mvn -Pyarn -Phadoop-2.4 -Dscala-2.10 -DskipTests clean package
```

Author: hyukjinkwon <gu...@gmail.com>

Closes #16680 from HyukjinKwon/SPARK-16101-schema-inference.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3d314d08
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3d314d08
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3d314d08

Branch: refs/heads/master
Commit: 3d314d08c9420e74b4bb687603cdd11394eccab5
Parents: 8fd178d
Author: hyukjinkwon <gu...@gmail.com>
Authored: Tue Feb 7 21:02:20 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Feb 7 21:02:20 2017 +0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/DataFrameReader.scala  |   4 +-
 .../datasources/csv/CSVFileFormat.scala         | 112 +------
 .../datasources/csv/CSVInferSchema.scala        | 115 ++++---
 .../execution/datasources/csv/CSVOptions.scala  |   2 +-
 .../execution/datasources/csv/CSVRelation.scala |  69 ----
 .../execution/datasources/csv/CSVUtils.scala    | 134 ++++++++
 .../datasources/json/InferSchema.scala          | 329 -------------------
 .../datasources/json/JsonFileFormat.scala       |   2 +-
 .../datasources/json/JsonInferSchema.scala      | 329 +++++++++++++++++++
 .../datasources/csv/CSVUtilsSuite.scala         |  47 +++
 .../datasources/csv/UnivocityParserSuite.scala  |  24 --
 .../execution/datasources/json/JsonSuite.scala  |   6 +-
 12 files changed, 599 insertions(+), 574 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index a787d5a..1830839 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.jdbc._
-import org.apache.spark.sql.execution.datasources.json.InferSchema
+import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -334,7 +334,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       parsedOptions.columnNameOfCorruptRecord
         .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
     val schema = userSpecifiedSchema.getOrElse {
-      InferSchema.infer(
+      JsonInferSchema.infer(
         jsonRDD,
         columnNameOfCorruptRecord,
         parsedOptions)

http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 3897016..1d2bf07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.csv
 
 import java.nio.charset.{Charset, StandardCharsets}
 
-import com.univocity.parsers.csv.CsvParser
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{LongWritable, Text}
@@ -28,13 +27,11 @@ import org.apache.hadoop.mapreduce._
 
 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
-import org.apache.spark.sql.functions.{length, trim}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
@@ -60,64 +57,9 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
 
     val csvOptions = new CSVOptions(options)
     val paths = files.map(_.getPath.toString)
-    val lines: Dataset[String] = readText(sparkSession, csvOptions, paths)
-    val firstLine: String = findFirstLine(csvOptions, lines)
-    val firstRow = new CsvParser(csvOptions.asParserSettings).parseLine(firstLine)
+    val lines: Dataset[String] = createBaseDataset(sparkSession, csvOptions, paths)
     val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
-    val header = makeSafeHeader(firstRow, csvOptions, caseSensitive)
-
-    val parsedRdd: RDD[Array[String]] = CSVRelation.univocityTokenizer(
-      lines,
-      firstLine = if (csvOptions.headerFlag) firstLine else null,
-      params = csvOptions)
-    val schema = if (csvOptions.inferSchemaFlag) {
-      CSVInferSchema.infer(parsedRdd, header, csvOptions)
-    } else {
-      // By default fields are assumed to be StringType
-      val schemaFields = header.map { fieldName =>
-        StructField(fieldName, StringType, nullable = true)
-      }
-      StructType(schemaFields)
-    }
-    Some(schema)
-  }
-
-  /**
-   * Generates a header from the given row which is null-safe and duplicate-safe.
-   */
-  private def makeSafeHeader(
-      row: Array[String],
-      options: CSVOptions,
-      caseSensitive: Boolean): Array[String] = {
-    if (options.headerFlag) {
-      val duplicates = {
-        val headerNames = row.filter(_ != null)
-          .map(name => if (caseSensitive) name else name.toLowerCase)
-        headerNames.diff(headerNames.distinct).distinct
-      }
-
-      row.zipWithIndex.map { case (value, index) =>
-        if (value == null || value.isEmpty || value == options.nullValue) {
-          // When there are empty strings or the values set in `nullValue`, put the
-          // index as the suffix.
-          s"_c$index"
-        } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
-          // When there are case-insensitive duplicates, put the index as the suffix.
-          s"$value$index"
-        } else if (duplicates.contains(value)) {
-          // When there are duplicates, put the index as the suffix.
-          s"$value$index"
-        } else {
-          value
-        }
-      }
-    } else {
-      row.zipWithIndex.map { case (_, index) =>
-        // Uses default column names, "_c#" where # is its position of fields
-        // when header option is disabled.
-        s"_c$index"
-      }
-    }
+    Some(CSVInferSchema.infer(lines, caseSensitive, csvOptions))
   }
 
   override def prepareWrite(
@@ -125,7 +67,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       job: Job,
       options: Map[String, String],
       dataSchema: StructType): OutputWriterFactory = {
-    verifySchema(dataSchema)
+    CSVUtils.verifySchema(dataSchema)
     val conf = job.getConfiguration
     val csvOptions = new CSVOptions(options)
     csvOptions.compressionCodec.foreach { codec =>
@@ -155,13 +97,12 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       options: Map[String, String],
       hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
     val csvOptions = new CSVOptions(options)
-    val commentPrefix = csvOptions.comment.toString
 
     val broadcastedHadoopConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
 
     (file: PartitionedFile) => {
-      val lineIterator = {
+      val lines = {
         val conf = broadcastedHadoopConf.value.value
         val linesReader = new HadoopFileLinesReader(file, conf)
         Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
@@ -170,32 +111,21 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
         }
       }
 
-      // Consumes the header in the iterator.
-      CSVRelation.dropHeaderLine(file, lineIterator, csvOptions)
-
-      val filteredIter = lineIterator.filter { line =>
-        line.trim.nonEmpty && !line.startsWith(commentPrefix)
+      val linesWithoutHeader = if (csvOptions.headerFlag && file.start == 0) {
+        // Note that if there are only comments in the first block, the header would probably
+        // be not dropped.
+        CSVUtils.dropHeaderLine(lines, csvOptions)
+      } else {
+        lines
       }
 
+      val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, csvOptions)
       val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions)
-      filteredIter.flatMap(parser.parse)
-    }
-  }
-
-  /**
-   * Returns the first line of the first non-empty file in path
-   */
-  private def findFirstLine(options: CSVOptions, lines: Dataset[String]): String = {
-    import lines.sqlContext.implicits._
-    val nonEmptyLines = lines.filter(length(trim($"value")) > 0)
-    if (options.isCommentSet) {
-      nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)).first()
-    } else {
-      nonEmptyLines.first()
+      filteredLines.flatMap(parser.parse)
     }
   }
 
-  private def readText(
+  private def createBaseDataset(
       sparkSession: SparkSession,
       options: CSVOptions,
       inputPaths: Seq[String]): Dataset[String] = {
@@ -215,22 +145,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession.createDataset(rdd)(Encoders.STRING)
     }
   }
-
-  private def verifySchema(schema: StructType): Unit = {
-    def verifyType(dataType: DataType): Unit = dataType match {
-        case ByteType | ShortType | IntegerType | LongType | FloatType |
-             DoubleType | BooleanType | _: DecimalType | TimestampType |
-             DateType | StringType =>
-
-        case udt: UserDefinedType[_] => verifyType(udt.sqlType)
-
-        case _ =>
-          throw new UnsupportedOperationException(
-            s"CSV data source does not support ${dataType.simpleString} data type.")
-    }
-
-    schema.foreach(field => verifyType(field.dataType))
-  }
 }
 
 private[csv] class CsvOutputWriter(

http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
index 065bf53..485b186 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala
@@ -18,17 +18,15 @@
 package org.apache.spark.sql.execution.datasources.csv
 
 import java.math.BigDecimal
-import java.text.NumberFormat
-import java.util.Locale
 
 import scala.util.control.Exception._
-import scala.util.Try
 
-import org.apache.spark.rdd.RDD
+import com.univocity.parsers.csv.CsvParser
+
 import org.apache.spark.sql.catalyst.analysis.TypeCoercion
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
 
 private[csv] object CSVInferSchema {
 
@@ -39,22 +37,76 @@ private[csv] object CSVInferSchema {
    *     3. Replace any null types with string type
    */
   def infer(
-      tokenRdd: RDD[Array[String]],
-      header: Array[String],
+      csv: Dataset[String],
+      caseSensitive: Boolean,
       options: CSVOptions): StructType = {
-    val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
-    val rootTypes: Array[DataType] =
-      tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)
-
-    val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
-      val dType = rootType match {
-        case _: NullType => StringType
-        case other => other
+    val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, options).first()
+    val firstRow = new CsvParser(options.asParserSettings).parseLine(firstLine)
+    val header = makeSafeHeader(firstRow, caseSensitive, options)
+
+    val fields = if (options.inferSchemaFlag) {
+      val tokenRdd = csv.rdd.mapPartitions { iter =>
+        val filteredLines = CSVUtils.filterCommentAndEmpty(iter, options)
+        val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, options)
+        val parser = new CsvParser(options.asParserSettings)
+        linesWithoutHeader.map(parser.parseLine)
+      }
+
+      val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
+      val rootTypes: Array[DataType] =
+        tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes)
+
+      header.zip(rootTypes).map { case (thisHeader, rootType) =>
+        val dType = rootType match {
+          case _: NullType => StringType
+          case other => other
+        }
+        StructField(thisHeader, dType, nullable = true)
       }
-      StructField(thisHeader, dType, nullable = true)
+    } else {
+      // By default fields are assumed to be StringType
+      header.map(fieldName => StructField(fieldName, StringType, nullable = true))
     }
 
-    StructType(structFields)
+    StructType(fields)
+  }
+
+  /**
+   * Generates a header from the given row which is null-safe and duplicate-safe.
+   */
+  private def makeSafeHeader(
+      row: Array[String],
+      caseSensitive: Boolean,
+      options: CSVOptions): Array[String] = {
+    if (options.headerFlag) {
+      val duplicates = {
+        val headerNames = row.filter(_ != null)
+          .map(name => if (caseSensitive) name else name.toLowerCase)
+        headerNames.diff(headerNames.distinct).distinct
+      }
+
+      row.zipWithIndex.map { case (value, index) =>
+        if (value == null || value.isEmpty || value == options.nullValue) {
+          // When there are empty strings or the values set in `nullValue`, put the
+          // index as the suffix.
+          s"_c$index"
+        } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
+          // When there are case-insensitive duplicates, put the index as the suffix.
+          s"$value$index"
+        } else if (duplicates.contains(value)) {
+          // When there are duplicates, put the index as the suffix.
+          s"$value$index"
+        } else {
+          value
+        }
+      }
+    } else {
+      row.zipWithIndex.map { case (_, index) =>
+        // Uses default column names, "_c#" where # is its position of fields
+        // when header option is disabled.
+        s"_c$index"
+      }
+    }
   }
 
   private def inferRowType(options: CSVOptions)
@@ -215,32 +267,3 @@ private[csv] object CSVInferSchema {
     case _ => None
   }
 }
-
-private[csv] object CSVTypeCast {
-  /**
-   * Helper method that converts string representation of a character to actual character.
-   * It handles some Java escaped strings and throws exception if given string is longer than one
-   * character.
-   */
-  @throws[IllegalArgumentException]
-  def toChar(str: String): Char = {
-    if (str.charAt(0) == '\\') {
-      str.charAt(1)
-      match {
-        case 't' => '\t'
-        case 'r' => '\r'
-        case 'b' => '\b'
-        case 'f' => '\f'
-        case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options
-        case '\'' => '\''
-        case 'u' if str == """\u0000""" => '\u0000'
-        case _ =>
-          throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str")
-      }
-    } else if (str.length == 1) {
-      str.charAt(0)
-    } else {
-      throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str")
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 140ce23..af456c8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -69,7 +69,7 @@ private[csv] class CSVOptions(@transient private val parameters: CaseInsensitive
     }
   }
 
-  val delimiter = CSVTypeCast.toChar(
+  val delimiter = CSVUtils.toChar(
     parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
   private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
   val charset = parameters.getOrElse("encoding",

http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
deleted file mode 100644
index 19058c2..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.csv
-
-import com.univocity.parsers.csv.CsvParser
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql._
-import org.apache.spark.sql.execution.datasources.PartitionedFile
-
-object CSVRelation extends Logging {
-
-  def univocityTokenizer(
-      file: Dataset[String],
-      firstLine: String,
-      params: CSVOptions): RDD[Array[String]] = {
-    // If header is set, make sure firstLine is materialized before sending to executors.
-    val commentPrefix = params.comment.toString
-    file.rdd.mapPartitions { iter =>
-      val parser = new CsvParser(params.asParserSettings)
-      val filteredIter = iter.filter { line =>
-        line.trim.nonEmpty && !line.startsWith(commentPrefix)
-      }
-      if (params.headerFlag) {
-        filteredIter.filterNot(_ == firstLine).map { item =>
-          parser.parseLine(item)
-        }
-      } else {
-        filteredIter.map { item =>
-          parser.parseLine(item)
-        }
-      }
-    }
-  }
-
-  // Skips the header line of each file if the `header` option is set to true.
-  def dropHeaderLine(
-      file: PartitionedFile, lines: Iterator[String], csvOptions: CSVOptions): Unit = {
-    // TODO What if the first partitioned file consists of only comments and empty lines?
-    if (csvOptions.headerFlag && file.start == 0) {
-      val nonEmptyLines = if (csvOptions.isCommentSet) {
-        val commentPrefix = csvOptions.comment.toString
-        lines.dropWhile { line =>
-          line.trim.isEmpty || line.trim.startsWith(commentPrefix)
-        }
-      } else {
-        lines.dropWhile(_.trim.isEmpty)
-      }
-
-      if (nonEmptyLines.hasNext) nonEmptyLines.drop(1)
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
new file mode 100644
index 0000000..72b053d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+object CSVUtils {
+  /**
+   * Filter ignorable rows for CSV dataset (lines empty and starting with `comment`).
+   * This is currently being used in CSV schema inference.
+   */
+  def filterCommentAndEmpty(lines: Dataset[String], options: CSVOptions): Dataset[String] = {
+    // Note that this was separately made by SPARK-18362. Logically, this should be the same
+    // with the one below, `filterCommentAndEmpty` but execution path is different. One of them
+    // might have to be removed in the near future if possible.
+    import lines.sqlContext.implicits._
+    val nonEmptyLines = lines.filter(length(trim($"value")) > 0)
+    if (options.isCommentSet) {
+      nonEmptyLines.filter(!$"value".startsWith(options.comment.toString))
+    } else {
+      nonEmptyLines
+    }
+  }
+
+  /**
+   * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`).
+   * This is currently being used in CSV reading path and CSV schema inference.
+   */
+  def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
+    iter.filter { line =>
+      line.trim.nonEmpty && !line.startsWith(options.comment.toString)
+    }
+  }
+
+  /**
+   * Skip the given first line so that only data can remain in a dataset.
+   * This is similar with `dropHeaderLine` below and currently being used in CSV schema inference.
+   */
+  def filterHeaderLine(
+       iter: Iterator[String],
+       firstLine: String,
+       options: CSVOptions): Iterator[String] = {
+    // Note that unlike actual CSV reading path, it simply filters the given first line. Therefore,
+    // this skips the line same with the header if exists. One of them might have to be removed
+    // in the near future if possible.
+    if (options.headerFlag) {
+      iter.filterNot(_ == firstLine)
+    } else {
+      iter
+    }
+  }
+
+  /**
+   * Drop header line so that only data can remain.
+   * This is similar with `filterHeaderLine` above and currently being used in CSV reading path.
+   */
+  def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = {
+    val nonEmptyLines = if (options.isCommentSet) {
+      val commentPrefix = options.comment.toString
+      iter.dropWhile { line =>
+        line.trim.isEmpty || line.trim.startsWith(commentPrefix)
+      }
+    } else {
+      iter.dropWhile(_.trim.isEmpty)
+    }
+
+    if (nonEmptyLines.hasNext) nonEmptyLines.drop(1)
+    iter
+  }
+
+  /**
+   * Helper method that converts string representation of a character to actual character.
+   * It handles some Java escaped strings and throws exception if given string is longer than one
+   * character.
+   */
+  @throws[IllegalArgumentException]
+  def toChar(str: String): Char = {
+    if (str.charAt(0) == '\\') {
+      str.charAt(1)
+      match {
+        case 't' => '\t'
+        case 'r' => '\r'
+        case 'b' => '\b'
+        case 'f' => '\f'
+        case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options
+        case '\'' => '\''
+        case 'u' if str == """\u0000""" => '\u0000'
+        case _ =>
+          throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str")
+      }
+    } else if (str.length == 1) {
+      str.charAt(0)
+    } else {
+      throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str")
+    }
+  }
+
+  /**
+   * Verify if the schema is supported in CSV datasource.
+   */
+  def verifySchema(schema: StructType): Unit = {
+    def verifyType(dataType: DataType): Unit = dataType match {
+      case ByteType | ShortType | IntegerType | LongType | FloatType |
+           DoubleType | BooleanType | _: DecimalType | TimestampType |
+           DateType | StringType =>
+
+      case udt: UserDefinedType[_] => verifyType(udt.sqlType)
+
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"CSV data source does not support ${dataType.simpleString} data type.")
+    }
+
+    schema.foreach(field => verifyType(field.dataType))
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
deleted file mode 100644
index 330d04d..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala
+++ /dev/null
@@ -1,329 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.json
-
-import java.util.Comparator
-
-import com.fasterxml.jackson.core._
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.TypeCoercion
-import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
-import org.apache.spark.sql.catalyst.json.JSONOptions
-import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
-
-private[sql] object InferSchema {
-
-  /**
-   * Infer the type of a collection of json records in three stages:
-   *   1. Infer the type of each record
-   *   2. Merge types by choosing the lowest type necessary to cover equal keys
-   *   3. Replace any remaining null fields with string, the top type
-   */
-  def infer(
-      json: RDD[String],
-      columnNameOfCorruptRecord: String,
-      configOptions: JSONOptions): StructType = {
-    require(configOptions.samplingRatio > 0,
-      s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
-    val shouldHandleCorruptRecord = configOptions.permissive
-    val schemaData = if (configOptions.samplingRatio > 0.99) {
-      json
-    } else {
-      json.sample(withReplacement = false, configOptions.samplingRatio, 1)
-    }
-
-    // perform schema inference on each row and merge afterwards
-    val rootType = schemaData.mapPartitions { iter =>
-      val factory = new JsonFactory()
-      configOptions.setJacksonOptions(factory)
-      iter.flatMap { row =>
-        try {
-          Utils.tryWithResource(factory.createParser(row)) { parser =>
-            parser.nextToken()
-            Some(inferField(parser, configOptions))
-          }
-        } catch {
-          case _: JsonParseException if shouldHandleCorruptRecord =>
-            Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType))))
-          case _: JsonParseException =>
-            None
-        }
-      }
-    }.fold(StructType(Seq()))(
-      compatibleRootType(columnNameOfCorruptRecord, shouldHandleCorruptRecord))
-
-    canonicalizeType(rootType) match {
-      case Some(st: StructType) => st
-      case _ =>
-        // canonicalizeType erases all empty structs, including the only one we want to keep
-        StructType(Seq())
-    }
-  }
-
-  private[this] val structFieldComparator = new Comparator[StructField] {
-    override def compare(o1: StructField, o2: StructField): Int = {
-      o1.name.compare(o2.name)
-    }
-  }
-
-  private def isSorted(arr: Array[StructField]): Boolean = {
-    var i: Int = 0
-    while (i < arr.length - 1) {
-      if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
-        return false
-      }
-      i += 1
-    }
-    true
-  }
-
-  /**
-   * Infer the type of a json document from the parser's token stream
-   */
-  private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
-    import com.fasterxml.jackson.core.JsonToken._
-    parser.getCurrentToken match {
-      case null | VALUE_NULL => NullType
-
-      case FIELD_NAME =>
-        parser.nextToken()
-        inferField(parser, configOptions)
-
-      case VALUE_STRING if parser.getTextLength < 1 =>
-        // Zero length strings and nulls have special handling to deal
-        // with JSON generators that do not distinguish between the two.
-        // To accurately infer types for empty strings that are really
-        // meant to represent nulls we assume that the two are isomorphic
-        // but will defer treating null fields as strings until all the
-        // record fields' types have been combined.
-        NullType
-
-      case VALUE_STRING => StringType
-      case START_OBJECT =>
-        val builder = Array.newBuilder[StructField]
-        while (nextUntil(parser, END_OBJECT)) {
-          builder += StructField(
-            parser.getCurrentName,
-            inferField(parser, configOptions),
-            nullable = true)
-        }
-        val fields: Array[StructField] = builder.result()
-        // Note: other code relies on this sorting for correctness, so don't remove it!
-        java.util.Arrays.sort(fields, structFieldComparator)
-        StructType(fields)
-
-      case START_ARRAY =>
-        // If this JSON array is empty, we use NullType as a placeholder.
-        // If this array is not empty in other JSON objects, we can resolve
-        // the type as we pass through all JSON objects.
-        var elementType: DataType = NullType
-        while (nextUntil(parser, END_ARRAY)) {
-          elementType = compatibleType(
-            elementType, inferField(parser, configOptions))
-        }
-
-        ArrayType(elementType)
-
-      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType
-
-      case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType
-
-      case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
-        import JsonParser.NumberType._
-        parser.getNumberType match {
-          // For Integer values, use LongType by default.
-          case INT | LONG => LongType
-          // Since we do not have a data type backed by BigInteger,
-          // when we see a Java BigInteger, we use DecimalType.
-          case BIG_INTEGER | BIG_DECIMAL =>
-            val v = parser.getDecimalValue
-            if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
-              DecimalType(Math.max(v.precision(), v.scale()), v.scale())
-            } else {
-              DoubleType
-            }
-          case FLOAT | DOUBLE if configOptions.prefersDecimal =>
-            val v = parser.getDecimalValue
-            if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
-              DecimalType(Math.max(v.precision(), v.scale()), v.scale())
-            } else {
-              DoubleType
-            }
-          case FLOAT | DOUBLE =>
-            DoubleType
-        }
-
-      case VALUE_TRUE | VALUE_FALSE => BooleanType
-    }
-  }
-
-  /**
-   * Convert NullType to StringType and remove StructTypes with no fields
-   */
-  private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match {
-    case at @ ArrayType(elementType, _) =>
-      for {
-        canonicalType <- canonicalizeType(elementType)
-      } yield {
-        at.copy(canonicalType)
-      }
-
-    case StructType(fields) =>
-      val canonicalFields: Array[StructField] = for {
-        field <- fields
-        if field.name.length > 0
-        canonicalType <- canonicalizeType(field.dataType)
-      } yield {
-        field.copy(dataType = canonicalType)
-      }
-
-      if (canonicalFields.length > 0) {
-        Some(StructType(canonicalFields))
-      } else {
-        // per SPARK-8093: empty structs should be deleted
-        None
-      }
-
-    case NullType => Some(StringType)
-    case other => Some(other)
-  }
-
-  private def withCorruptField(
-      struct: StructType,
-      columnNameOfCorruptRecords: String): StructType = {
-    if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
-      // If this given struct does not have a column used for corrupt records,
-      // add this field.
-      val newFields: Array[StructField] =
-        StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
-      // Note: other code relies on this sorting for correctness, so don't remove it!
-      java.util.Arrays.sort(newFields, structFieldComparator)
-      StructType(newFields)
-    } else {
-      // Otherwise, just return this struct.
-      struct
-    }
-  }
-
-  /**
-   * Remove top-level ArrayType wrappers and merge the remaining schemas
-   */
-  private def compatibleRootType(
-      columnNameOfCorruptRecords: String,
-      shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = {
-    // Since we support array of json objects at the top level,
-    // we need to check the element type and find the root level data type.
-    case (ArrayType(ty1, _), ty2) =>
-      compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2)
-    case (ty1, ArrayType(ty2, _)) =>
-      compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2)
-    // If we see any other data type at the root level, we get records that cannot be
-    // parsed. So, we use the struct as the data type and add the corrupt field to the schema.
-    case (struct: StructType, NullType) => struct
-    case (NullType, struct: StructType) => struct
-    case (struct: StructType, o) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord =>
-      withCorruptField(struct, columnNameOfCorruptRecords)
-    case (o, struct: StructType) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord =>
-      withCorruptField(struct, columnNameOfCorruptRecords)
-    // If we get anything else, we call compatibleType.
-    // Usually, when we reach here, ty1 and ty2 are two StructTypes.
-    case (ty1, ty2) => compatibleType(ty1, ty2)
-  }
-
-  private[this] val emptyStructFieldArray = Array.empty[StructField]
-
-  /**
-   * Returns the most general data type for two given data types.
-   */
-  def compatibleType(t1: DataType, t2: DataType): DataType = {
-    TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
-      // t1 or t2 is a StructType, ArrayType, or an unexpected type.
-      (t1, t2) match {
-        // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
-        // in most case, also have better precision.
-        case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
-          DoubleType
-
-        case (t1: DecimalType, t2: DecimalType) =>
-          val scale = math.max(t1.scale, t2.scale)
-          val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
-          if (range + scale > 38) {
-            // DecimalType can't support precision > 38
-            DoubleType
-          } else {
-            DecimalType(range + scale, scale)
-          }
-
-        case (StructType(fields1), StructType(fields2)) =>
-          // Both fields1 and fields2 should be sorted by name, since inferField performs sorting.
-          // Therefore, we can take advantage of the fact that we're merging sorted lists and skip
-          // building a hash map or performing additional sorting.
-          assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}")
-          assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}")
-
-          val newFields = new java.util.ArrayList[StructField]()
-
-          var f1Idx = 0
-          var f2Idx = 0
-
-          while (f1Idx < fields1.length && f2Idx < fields2.length) {
-            val f1Name = fields1(f1Idx).name
-            val f2Name = fields2(f2Idx).name
-            val comp = f1Name.compareTo(f2Name)
-            if (comp == 0) {
-              val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
-              newFields.add(StructField(f1Name, dataType, nullable = true))
-              f1Idx += 1
-              f2Idx += 1
-            } else if (comp < 0) { // f1Name < f2Name
-              newFields.add(fields1(f1Idx))
-              f1Idx += 1
-            } else { // f1Name > f2Name
-              newFields.add(fields2(f2Idx))
-              f2Idx += 1
-            }
-          }
-          while (f1Idx < fields1.length) {
-            newFields.add(fields1(f1Idx))
-            f1Idx += 1
-          }
-          while (f2Idx < fields2.length) {
-            newFields.add(fields2(f2Idx))
-            f2Idx += 1
-          }
-          StructType(newFields.toArray(emptyStructFieldArray))
-
-        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
-          ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
-
-        // The case that given `DecimalType` is capable of given `IntegralType` is handled in
-        // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when
-        // the given `DecimalType` is not capable of the given `IntegralType`.
-        case (t1: IntegralType, t2: DecimalType) =>
-          compatibleType(DecimalType.forType(t1), t2)
-        case (t1: DecimalType, t2: IntegralType) =>
-          compatibleType(t1, DecimalType.forType(t2))
-
-        // strings and every string is a Json object.
-        case (_, _) => StringType
-      }
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index be1f94d..98ab9d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -51,7 +51,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       val columnNameOfCorruptRecord =
         parsedOptions.columnNameOfCorruptRecord
           .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord)
-      val jsonSchema = InferSchema.infer(
+      val jsonSchema = JsonInferSchema.infer(
         createBaseRdd(sparkSession, files),
         columnNameOfCorruptRecord,
         parsedOptions)

http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
new file mode 100644
index 0000000..f51c18d
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -0,0 +1,329 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.json
+
+import java.util.Comparator
+
+import com.fasterxml.jackson.core._
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion
+import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
+import org.apache.spark.sql.catalyst.json.JSONOptions
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+private[sql] object JsonInferSchema {
+
+  /**
+   * Infer the type of a collection of json records in three stages:
+   *   1. Infer the type of each record
+   *   2. Merge types by choosing the lowest type necessary to cover equal keys
+   *   3. Replace any remaining null fields with string, the top type
+   */
+  def infer(
+      json: RDD[String],
+      columnNameOfCorruptRecord: String,
+      configOptions: JSONOptions): StructType = {
+    require(configOptions.samplingRatio > 0,
+      s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
+    val shouldHandleCorruptRecord = configOptions.permissive
+    val schemaData = if (configOptions.samplingRatio > 0.99) {
+      json
+    } else {
+      json.sample(withReplacement = false, configOptions.samplingRatio, 1)
+    }
+
+    // perform schema inference on each row and merge afterwards
+    val rootType = schemaData.mapPartitions { iter =>
+      val factory = new JsonFactory()
+      configOptions.setJacksonOptions(factory)
+      iter.flatMap { row =>
+        try {
+          Utils.tryWithResource(factory.createParser(row)) { parser =>
+            parser.nextToken()
+            Some(inferField(parser, configOptions))
+          }
+        } catch {
+          case _: JsonParseException if shouldHandleCorruptRecord =>
+            Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType))))
+          case _: JsonParseException =>
+            None
+        }
+      }
+    }.fold(StructType(Seq()))(
+      compatibleRootType(columnNameOfCorruptRecord, shouldHandleCorruptRecord))
+
+    canonicalizeType(rootType) match {
+      case Some(st: StructType) => st
+      case _ =>
+        // canonicalizeType erases all empty structs, including the only one we want to keep
+        StructType(Seq())
+    }
+  }
+
+  private[this] val structFieldComparator = new Comparator[StructField] {
+    override def compare(o1: StructField, o2: StructField): Int = {
+      o1.name.compare(o2.name)
+    }
+  }
+
+  private def isSorted(arr: Array[StructField]): Boolean = {
+    var i: Int = 0
+    while (i < arr.length - 1) {
+      if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
+        return false
+      }
+      i += 1
+    }
+    true
+  }
+
+  /**
+   * Infer the type of a json document from the parser's token stream
+   */
+  private def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
+    import com.fasterxml.jackson.core.JsonToken._
+    parser.getCurrentToken match {
+      case null | VALUE_NULL => NullType
+
+      case FIELD_NAME =>
+        parser.nextToken()
+        inferField(parser, configOptions)
+
+      case VALUE_STRING if parser.getTextLength < 1 =>
+        // Zero length strings and nulls have special handling to deal
+        // with JSON generators that do not distinguish between the two.
+        // To accurately infer types for empty strings that are really
+        // meant to represent nulls we assume that the two are isomorphic
+        // but will defer treating null fields as strings until all the
+        // record fields' types have been combined.
+        NullType
+
+      case VALUE_STRING => StringType
+      case START_OBJECT =>
+        val builder = Array.newBuilder[StructField]
+        while (nextUntil(parser, END_OBJECT)) {
+          builder += StructField(
+            parser.getCurrentName,
+            inferField(parser, configOptions),
+            nullable = true)
+        }
+        val fields: Array[StructField] = builder.result()
+        // Note: other code relies on this sorting for correctness, so don't remove it!
+        java.util.Arrays.sort(fields, structFieldComparator)
+        StructType(fields)
+
+      case START_ARRAY =>
+        // If this JSON array is empty, we use NullType as a placeholder.
+        // If this array is not empty in other JSON objects, we can resolve
+        // the type as we pass through all JSON objects.
+        var elementType: DataType = NullType
+        while (nextUntil(parser, END_ARRAY)) {
+          elementType = compatibleType(
+            elementType, inferField(parser, configOptions))
+        }
+
+        ArrayType(elementType)
+
+      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType
+
+      case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType
+
+      case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
+        import JsonParser.NumberType._
+        parser.getNumberType match {
+          // For Integer values, use LongType by default.
+          case INT | LONG => LongType
+          // Since we do not have a data type backed by BigInteger,
+          // when we see a Java BigInteger, we use DecimalType.
+          case BIG_INTEGER | BIG_DECIMAL =>
+            val v = parser.getDecimalValue
+            if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
+              DecimalType(Math.max(v.precision(), v.scale()), v.scale())
+            } else {
+              DoubleType
+            }
+          case FLOAT | DOUBLE if configOptions.prefersDecimal =>
+            val v = parser.getDecimalValue
+            if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
+              DecimalType(Math.max(v.precision(), v.scale()), v.scale())
+            } else {
+              DoubleType
+            }
+          case FLOAT | DOUBLE =>
+            DoubleType
+        }
+
+      case VALUE_TRUE | VALUE_FALSE => BooleanType
+    }
+  }
+
+  /**
+   * Convert NullType to StringType and remove StructTypes with no fields
+   */
+  private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match {
+    case at @ ArrayType(elementType, _) =>
+      for {
+        canonicalType <- canonicalizeType(elementType)
+      } yield {
+        at.copy(canonicalType)
+      }
+
+    case StructType(fields) =>
+      val canonicalFields: Array[StructField] = for {
+        field <- fields
+        if field.name.length > 0
+        canonicalType <- canonicalizeType(field.dataType)
+      } yield {
+        field.copy(dataType = canonicalType)
+      }
+
+      if (canonicalFields.length > 0) {
+        Some(StructType(canonicalFields))
+      } else {
+        // per SPARK-8093: empty structs should be deleted
+        None
+      }
+
+    case NullType => Some(StringType)
+    case other => Some(other)
+  }
+
+  private def withCorruptField(
+      struct: StructType,
+      columnNameOfCorruptRecords: String): StructType = {
+    if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
+      // If this given struct does not have a column used for corrupt records,
+      // add this field.
+      val newFields: Array[StructField] =
+        StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
+      // Note: other code relies on this sorting for correctness, so don't remove it!
+      java.util.Arrays.sort(newFields, structFieldComparator)
+      StructType(newFields)
+    } else {
+      // Otherwise, just return this struct.
+      struct
+    }
+  }
+
+  /**
+   * Remove top-level ArrayType wrappers and merge the remaining schemas
+   */
+  private def compatibleRootType(
+      columnNameOfCorruptRecords: String,
+      shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = {
+    // Since we support array of json objects at the top level,
+    // we need to check the element type and find the root level data type.
+    case (ArrayType(ty1, _), ty2) =>
+      compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2)
+    case (ty1, ArrayType(ty2, _)) =>
+      compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2)
+    // If we see any other data type at the root level, we get records that cannot be
+    // parsed. So, we use the struct as the data type and add the corrupt field to the schema.
+    case (struct: StructType, NullType) => struct
+    case (NullType, struct: StructType) => struct
+    case (struct: StructType, o) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord =>
+      withCorruptField(struct, columnNameOfCorruptRecords)
+    case (o, struct: StructType) if !o.isInstanceOf[StructType] && shouldHandleCorruptRecord =>
+      withCorruptField(struct, columnNameOfCorruptRecords)
+    // If we get anything else, we call compatibleType.
+    // Usually, when we reach here, ty1 and ty2 are two StructTypes.
+    case (ty1, ty2) => compatibleType(ty1, ty2)
+  }
+
+  private[this] val emptyStructFieldArray = Array.empty[StructField]
+
+  /**
+   * Returns the most general data type for two given data types.
+   */
+  def compatibleType(t1: DataType, t2: DataType): DataType = {
+    TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
+      // t1 or t2 is a StructType, ArrayType, or an unexpected type.
+      (t1, t2) match {
+        // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
+        // in most case, also have better precision.
+        case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
+          DoubleType
+
+        case (t1: DecimalType, t2: DecimalType) =>
+          val scale = math.max(t1.scale, t2.scale)
+          val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+          if (range + scale > 38) {
+            // DecimalType can't support precision > 38
+            DoubleType
+          } else {
+            DecimalType(range + scale, scale)
+          }
+
+        case (StructType(fields1), StructType(fields2)) =>
+          // Both fields1 and fields2 should be sorted by name, since inferField performs sorting.
+          // Therefore, we can take advantage of the fact that we're merging sorted lists and skip
+          // building a hash map or performing additional sorting.
+          assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}")
+          assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}")
+
+          val newFields = new java.util.ArrayList[StructField]()
+
+          var f1Idx = 0
+          var f2Idx = 0
+
+          while (f1Idx < fields1.length && f2Idx < fields2.length) {
+            val f1Name = fields1(f1Idx).name
+            val f2Name = fields2(f2Idx).name
+            val comp = f1Name.compareTo(f2Name)
+            if (comp == 0) {
+              val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
+              newFields.add(StructField(f1Name, dataType, nullable = true))
+              f1Idx += 1
+              f2Idx += 1
+            } else if (comp < 0) { // f1Name < f2Name
+              newFields.add(fields1(f1Idx))
+              f1Idx += 1
+            } else { // f1Name > f2Name
+              newFields.add(fields2(f2Idx))
+              f2Idx += 1
+            }
+          }
+          while (f1Idx < fields1.length) {
+            newFields.add(fields1(f1Idx))
+            f1Idx += 1
+          }
+          while (f2Idx < fields2.length) {
+            newFields.add(fields2(f2Idx))
+            f2Idx += 1
+          }
+          StructType(newFields.toArray(emptyStructFieldArray))
+
+        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+          ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+
+        // The case that given `DecimalType` is capable of given `IntegralType` is handled in
+        // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when
+        // the given `DecimalType` is not capable of the given `IntegralType`.
+        case (t1: IntegralType, t2: DecimalType) =>
+          compatibleType(DecimalType.forType(t1), t2)
+        case (t1: DecimalType, t2: IntegralType) =>
+          compatibleType(t1, DecimalType.forType(t2))
+
+        // strings and every string is a Json object.
+        case (_, _) => StringType
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala
new file mode 100644
index 0000000..221e44c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import org.apache.spark.SparkFunSuite
+
+class CSVUtilsSuite extends SparkFunSuite {
+  test("Can parse escaped characters") {
+    assert(CSVUtils.toChar("""\t""") === '\t')
+    assert(CSVUtils.toChar("""\r""") === '\r')
+    assert(CSVUtils.toChar("""\b""") === '\b')
+    assert(CSVUtils.toChar("""\f""") === '\f')
+    assert(CSVUtils.toChar("""\"""") === '\"')
+    assert(CSVUtils.toChar("""\'""") === '\'')
+    assert(CSVUtils.toChar("""\u0000""") === '\u0000')
+  }
+
+  test("Does not accept delimiter larger than one character") {
+    val exception = intercept[IllegalArgumentException]{
+      CSVUtils.toChar("ab")
+    }
+    assert(exception.getMessage.contains("cannot be more than one character"))
+  }
+
+  test("Throws exception for unsupported escaped characters") {
+    val exception = intercept[IllegalArgumentException]{
+      CSVUtils.toChar("""\1""")
+    }
+    assert(exception.getMessage.contains("Unsupported special character for delimiter"))
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
index 2ca6308..62dae08 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala
@@ -43,30 +43,6 @@ class UnivocityParserSuite extends SparkFunSuite {
     }
   }
 
-  test("Can parse escaped characters") {
-    assert(CSVTypeCast.toChar("""\t""") === '\t')
-    assert(CSVTypeCast.toChar("""\r""") === '\r')
-    assert(CSVTypeCast.toChar("""\b""") === '\b')
-    assert(CSVTypeCast.toChar("""\f""") === '\f')
-    assert(CSVTypeCast.toChar("""\"""") === '\"')
-    assert(CSVTypeCast.toChar("""\'""") === '\'')
-    assert(CSVTypeCast.toChar("""\u0000""") === '\u0000')
-  }
-
-  test("Does not accept delimiter larger than one character") {
-    val exception = intercept[IllegalArgumentException]{
-      CSVTypeCast.toChar("ab")
-    }
-    assert(exception.getMessage.contains("cannot be more than one character"))
-  }
-
-  test("Throws exception for unsupported escaped characters") {
-    val exception = intercept[IllegalArgumentException]{
-      CSVTypeCast.toChar("""\1""")
-    }
-    assert(exception.getMessage.contains("Unsupported special character for delimiter"))
-  }
-
   test("Nullable types are handled") {
     val types = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
       BooleanType, DecimalType.DoubleDecimal, TimestampType, DateType, StringType)

http://git-wip-us.apache.org/repos/asf/spark/blob/3d314d08/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 161a409..156fd96 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType
+import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -1366,7 +1366,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
 
   test("SPARK-6245 JsonRDD.inferSchema on empty RDD") {
     // This is really a test that it doesn't throw an exception
-    val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map.empty[String, String]))
+    val emptySchema = JsonInferSchema.infer(empty, "", new JSONOptions(Map.empty[String, String]))
     assert(StructType(Seq()) === emptySchema)
   }
 
@@ -1390,7 +1390,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
   }
 
   test("SPARK-8093 Erase empty structs") {
-    val emptySchema = InferSchema.infer(
+    val emptySchema = JsonInferSchema.infer(
       emptyRecords, "", new JSONOptions(Map.empty[String, String]))
     assert(StructType(Seq()) === emptySchema)
   }

