Posted to commits@spark.apache.org by li...@apache.org on 2017/03/21 04:43:17 UTC

spark git commit: [SPARK-19949][SQL] unify bad record handling in CSV and JSON

Repository: spark
Updated Branches:
  refs/heads/master 21e366aea -> 68d65fae7


[SPARK-19949][SQL] unify bad record handling in CSV and JSON

## What changes were proposed in this pull request?

Currently, JSON and CSV have exactly the same logic for handling bad records. This PR abstracts that logic and moves it to an upper level to reduce code duplication.

The overall idea is to make the JSON and CSV parsers throw a BadRecordException, and let an upper-level component, FailureSafeParser, handle bad records according to the parse mode.
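
For readers skimming the diff, here is a heavily simplified, standalone sketch of this pattern. It is not the actual Spark code: it uses plain Scala collections instead of InternalRow/UTF8String and string literals for the parse modes, and the real classes added below also take the schema and the corrupt-record column name.

```scala
// Simplified sketch: the low-level parser throws BadRecordException, and a
// failure-safe wrapper interprets it according to the parse mode.
case class BadRecordException(
    record: () => String,                   // lazily renders the raw bad record
    partialResult: () => Option[Seq[Any]],  // best-effort partially parsed row, if any
    cause: Throwable) extends Exception(cause)

class FailureSafeParser[IN](rawParser: IN => Seq[Seq[Any]], mode: String) {
  def parse(input: IN): Iterator[Seq[Any]] = {
    try {
      rawParser(input).toIterator
    } catch {
      case e: BadRecordException if mode == "PERMISSIVE" =>
        // Keep the record: use the partial result if there is one, otherwise
        // fall back to a row that only carries the raw record string.
        Iterator(e.partialResult().getOrElse(Seq(e.record())))
      case _: BadRecordException if mode == "DROPMALFORMED" =>
        Iterator.empty                      // silently drop the bad record
      case e: BadRecordException =>
        throw e.cause                       // FAILFAST: surface the real error
    }
  }
}
```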

Behavior changes:
1. In PERMISSIVE mode, if the number of tokens doesn't match the schema, the CSV parser previously treated the record as legal and parsed as many tokens as possible. After this PR, we treat it as a malformed record and put the raw record string in the corrupt-record column, while still parsing as many tokens as possible (see the usage sketch after this list).
2. All logging of malformed records is removed, as it was not very useful in practice.
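
A hypothetical usage sketch of behavior change 1, assuming a local SparkSession; the file path and column names are illustrative and not part of this commit:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
import spark.implicits._

// "1,a" has too few tokens for the 3-column data schema below.
val path = "/tmp/demo.csv"  // hypothetical path
Seq("1,a,true", "1,a").toDS().write.mode("overwrite").text(path)

val schema = new StructType()
  .add("i", IntegerType)
  .add("s", StringType)
  .add("b", BooleanType)
  .add("_corrupt_record", StringType)  // column receiving the raw bad record

spark.read
  .schema(schema)
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_corrupt_record")
  .csv(path)
  .show(false)
// With this change, the second row is expected to come back as
// i=1, s=a, b=null and _corrupt_record="1,a"; previously it was
// treated as a legal record with no corrupt-record marker.
```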

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <we...@databricks.com>
Author: hyukjinkwon <gu...@gmail.com>
Author: Wenchen Fan <cl...@gmail.com>

Closes #17315 from cloud-fan/bad-record2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/68d65fae
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/68d65fae
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/68d65fae

Branch: refs/heads/master
Commit: 68d65fae71e475ad811a9716098aca03a2af9532
Parents: 21e366a
Author: Wenchen Fan <we...@databricks.com>
Authored: Mon Mar 20 21:43:14 2017 -0700
Committer: Xiao Li <ga...@gmail.com>
Committed: Mon Mar 20 21:43:14 2017 -0700

----------------------------------------------------------------------
 R/pkg/inst/tests/testthat/test_sparkSQL.R       |   5 +-
 .../catalyst/expressions/jsonExpressions.scala  |   4 +-
 .../spark/sql/catalyst/json/JSONOptions.scala   |   2 +-
 .../spark/sql/catalyst/json/JacksonParser.scala | 122 +-----------
 .../sql/catalyst/util/FailureSafeParser.scala   |  80 ++++++++
 .../org/apache/spark/sql/DataFrameReader.scala  |  23 ++-
 .../datasources/csv/CSVDataSource.scala         |  17 +-
 .../datasources/csv/CSVFileFormat.scala         |   7 +-
 .../execution/datasources/csv/CSVOptions.scala  |   2 +-
 .../datasources/csv/UnivocityParser.scala       | 197 ++++++-------------
 .../datasources/json/JsonDataSource.scala       |  31 ++-
 .../datasources/json/JsonFileFormat.scala       |   7 +-
 .../execution/datasources/csv/CSVSuite.scala    |   2 +-
 .../execution/datasources/json/JsonSuite.scala  |   8 +-
 14 files changed, 222 insertions(+), 285 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/R/pkg/inst/tests/testthat/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index cbc3569..394d1a0 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1370,9 +1370,8 @@ test_that("column functions", {
   # passing option
   df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}")))
   schema2 <- structType(structField("date", "date"))
-  expect_error(tryCatch(collect(select(df, from_json(df$col, schema2))),
-                        error = function(e) { stop(e) }),
-               paste0(".*(java.lang.NumberFormatException: For input string:).*"))
+  s <- collect(select(df, from_json(df$col, schema2)))
+  expect_equal(s[[1]][[1]], NA)
   s <- collect(select(df, from_json(df$col, schema2, dateFormat = "dd/MM/yyyy")))
   expect_is(s[[1]][[1]]$date, "Date")
   expect_equal(as.character(s[[1]][[1]]$date), "2014-10-21")

http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index e4e08a8..08af552 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, ParseModes}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, GenericArrayData, ParseModes}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
@@ -583,7 +583,7 @@ case class JsonToStructs(
         CreateJacksonParser.utf8String,
         identity[UTF8String]))
     } catch {
-      case _: SparkSQLJsonProcessingException => null
+      case _: BadRecordException => null
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index 5f222ec..355c26a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -65,7 +65,7 @@ private[sql] class JSONOptions(
   val allowBackslashEscapingAnyCharacter =
     parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false)
   val compressionCodec = parameters.get("compression").map(CompressionCodecs.getCodecClassName)
-  private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
+  val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
   val columnNameOfCorruptRecord =
     parameters.getOrElse("columnNameOfCorruptRecord", defaultColumnNameOfCorruptRecord)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index 9b80c0f..fdb7d88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -32,17 +32,14 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
-private[sql] class SparkSQLJsonProcessingException(msg: String) extends RuntimeException(msg)
-
 /**
  * Constructs a parser for a given schema that translates a json string to an [[InternalRow]].
  */
 class JacksonParser(
     schema: StructType,
-    options: JSONOptions) extends Logging {
+    val options: JSONOptions) extends Logging {
 
   import JacksonUtils._
-  import ParseModes._
   import com.fasterxml.jackson.core.JsonToken._
 
   // A `ValueConverter` is responsible for converting a value from `JsonParser`
@@ -55,108 +52,6 @@ class JacksonParser(
   private val factory = new JsonFactory()
   options.setJacksonOptions(factory)
 
-  private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length))
-
-  private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord)
-  corruptFieldIndex.foreach { corrFieldIndex =>
-    require(schema(corrFieldIndex).dataType == StringType)
-    require(schema(corrFieldIndex).nullable)
-  }
-
-  @transient
-  private[this] var isWarningPrinted: Boolean = false
-
-  @transient
-  private def printWarningForMalformedRecord(record: () => UTF8String): Unit = {
-    def sampleRecord: String = {
-      if (options.wholeFile) {
-        ""
-      } else {
-        s"Sample record: ${record()}\n"
-      }
-    }
-
-    def footer: String = {
-      s"""Code example to print all malformed records (scala):
-         |===================================================
-         |// The corrupted record exists in column ${options.columnNameOfCorruptRecord}.
-         |val parsedJson = spark.read.json("/path/to/json/file/test.json")
-         |
-       """.stripMargin
-    }
-
-    if (options.permissive) {
-      logWarning(
-        s"""Found at least one malformed record. The JSON reader will replace
-           |all malformed records with placeholder null in current $PERMISSIVE_MODE parser mode.
-           |To find out which corrupted records have been replaced with null, please use the
-           |default inferred schema instead of providing a custom schema.
-           |
-           |${sampleRecord ++ footer}
-           |
-         """.stripMargin)
-    } else if (options.dropMalformed) {
-      logWarning(
-        s"""Found at least one malformed record. The JSON reader will drop
-           |all malformed records in current $DROP_MALFORMED_MODE parser mode. To find out which
-           |corrupted records have been dropped, please switch the parser mode to $PERMISSIVE_MODE
-           |mode and use the default inferred schema.
-           |
-           |${sampleRecord ++ footer}
-           |
-         """.stripMargin)
-    }
-  }
-
-  @transient
-  private def printWarningIfWholeFile(): Unit = {
-    if (options.wholeFile && corruptFieldIndex.isDefined) {
-      logWarning(
-        s"""Enabling wholeFile mode and defining columnNameOfCorruptRecord may result
-           |in very large allocations or OutOfMemoryExceptions being raised.
-           |
-         """.stripMargin)
-    }
-  }
-
-  /**
-   * This function deals with the cases it fails to parse. This function will be called
-   * when exceptions are caught during converting. This functions also deals with `mode` option.
-   */
-  private def failedRecord(record: () => UTF8String): Seq[InternalRow] = {
-    corruptFieldIndex match {
-      case _ if options.failFast =>
-        if (options.wholeFile) {
-          throw new SparkSQLJsonProcessingException("Malformed line in FAILFAST mode")
-        } else {
-          throw new SparkSQLJsonProcessingException(s"Malformed line in FAILFAST mode: ${record()}")
-        }
-
-      case _ if options.dropMalformed =>
-        if (!isWarningPrinted) {
-          printWarningForMalformedRecord(record)
-          isWarningPrinted = true
-        }
-        Nil
-
-      case None =>
-        if (!isWarningPrinted) {
-          printWarningForMalformedRecord(record)
-          isWarningPrinted = true
-        }
-        emptyRow
-
-      case Some(corruptIndex) =>
-        if (!isWarningPrinted) {
-          printWarningIfWholeFile()
-          isWarningPrinted = true
-        }
-        val row = new GenericInternalRow(schema.length)
-        row.update(corruptIndex, record())
-        Seq(row)
-    }
-  }
-
   /**
    * Create a converter which converts the JSON documents held by the `JsonParser`
    * to a value according to a desired schema. This is a wrapper for the method
@@ -239,7 +134,7 @@ class JacksonParser(
             lowerCaseValue.equals("-inf")) {
             value.toFloat
           } else {
-            throw new SparkSQLJsonProcessingException(s"Cannot parse $value as FloatType.")
+            throw new RuntimeException(s"Cannot parse $value as FloatType.")
           }
       }
 
@@ -259,7 +154,7 @@ class JacksonParser(
             lowerCaseValue.equals("-inf")) {
             value.toDouble
           } else {
-            throw new SparkSQLJsonProcessingException(s"Cannot parse $value as DoubleType.")
+            throw new RuntimeException(s"Cannot parse $value as DoubleType.")
           }
       }
 
@@ -391,9 +286,8 @@ class JacksonParser(
 
     case token =>
       // We cannot parse this token based on the given data type. So, we throw a
-      // SparkSQLJsonProcessingException and this exception will be caught by
-      // `parse` method.
-      throw new SparkSQLJsonProcessingException(
+      // RuntimeException and this exception will be caught by `parse` method.
+      throw new RuntimeException(
         s"Failed to parse a value for data type $dataType (current token: $token).")
   }
 
@@ -466,14 +360,14 @@ class JacksonParser(
         parser.nextToken() match {
           case null => Nil
           case _ => rootConverter.apply(parser) match {
-            case null => throw new SparkSQLJsonProcessingException("Root converter returned null")
+            case null => throw new RuntimeException("Root converter returned null")
             case rows => rows
           }
         }
       }
     } catch {
-      case _: JsonProcessingException | _: SparkSQLJsonProcessingException =>
-        failedRecord(() => recordLiteral(record))
+      case e @ (_: RuntimeException | _: JsonProcessingException) =>
+        throw BadRecordException(() => recordLiteral(record), () => None, e)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
new file mode 100644
index 0000000..e8da10d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
+
+class FailureSafeParser[IN](
+    rawParser: IN => Seq[InternalRow],
+    mode: String,
+    schema: StructType,
+    columnNameOfCorruptRecord: String) {
+
+  private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord)
+  private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord))
+  private val resultRow = new GenericInternalRow(schema.length)
+  private val nullResult = new GenericInternalRow(schema.length)
+
+  // This function takes 2 parameters: an optional partial result, and the bad record. If the given
+  // schema doesn't contain a field for corrupted record, we just return the partial result or a
+  // row with all fields null. If the given schema contains a field for corrupted record, we will
+  // set the bad record to this field, and set other fields according to the partial result or null.
+  private val toResultRow: (Option[InternalRow], () => UTF8String) => InternalRow = {
+    if (corruptFieldIndex.isDefined) {
+      (row, badRecord) => {
+        var i = 0
+        while (i < actualSchema.length) {
+          val from = actualSchema(i)
+          resultRow(schema.fieldIndex(from.name)) = row.map(_.get(i, from.dataType)).orNull
+          i += 1
+        }
+        resultRow(corruptFieldIndex.get) = badRecord()
+        resultRow
+      }
+    } else {
+      (row, _) => row.getOrElse(nullResult)
+    }
+  }
+
+  def parse(input: IN): Iterator[InternalRow] = {
+    try {
+      rawParser.apply(input).toIterator.map(row => toResultRow(Some(row), () => null))
+    } catch {
+      case e: BadRecordException if ParseModes.isPermissiveMode(mode) =>
+        Iterator(toResultRow(e.partialResult(), e.record))
+      case _: BadRecordException if ParseModes.isDropMalformedMode(mode) =>
+        Iterator.empty
+      case e: BadRecordException => throw e.cause
+    }
+  }
+}
+
+/**
+ * Exception thrown when the underlying parser meet a bad record and can't parse it.
+ * @param record a function to return the record that cause the parser to fail
+ * @param partialResult a function that returns an optional row, which is the partial result of
+ *                      parsing this bad record.
+ * @param cause the actual exception about why the record is bad and can't be parsed.
+ */
+case class BadRecordException(
+    record: () => UTF8String,
+    partialResult: () => Option[InternalRow],
+    cause: Throwable) extends Exception(cause)

http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 88fbfb4..767a636 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -27,6 +27,7 @@ import org.apache.spark.Partition
 import org.apache.spark.annotation.InterfaceStability
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.catalyst.util.FailureSafeParser
 import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.csv._
@@ -382,11 +383,18 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
 
     verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+    val actualSchema =
+      StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
     val createParser = CreateJacksonParser.string _
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
-      val parser = new JacksonParser(schema, parsedOptions)
-      iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))
+      val rawParser = new JacksonParser(actualSchema, parsedOptions)
+      val parser = new FailureSafeParser[String](
+        input => rawParser.parse(input, createParser, UTF8String.fromString),
+        parsedOptions.parseMode,
+        schema,
+        parsedOptions.columnNameOfCorruptRecord)
+      iter.flatMap(parser.parse)
     }
 
     Dataset.ofRows(
@@ -435,14 +443,21 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
 
     verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
+    val actualSchema =
+      StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
 
     val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
       filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
     }.getOrElse(filteredLines.rdd)
 
     val parsed = linesWithoutHeader.mapPartitions { iter =>
-      val parser = new UnivocityParser(schema, parsedOptions)
-      iter.flatMap(line => parser.parse(line))
+      val rawParser = new UnivocityParser(actualSchema, parsedOptions)
+      val parser = new FailureSafeParser[String](
+        input => Seq(rawParser.parse(input)),
+        parsedOptions.parseMode,
+        schema,
+        parsedOptions.columnNameOfCorruptRecord)
+      iter.flatMap(parser.parse)
     }
 
     Dataset.ofRows(

http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 35ff924..63af18e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -49,7 +49,7 @@ abstract class CSVDataSource extends Serializable {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      parsedOptions: CSVOptions): Iterator[InternalRow]
+      schema: StructType): Iterator[InternalRow]
 
   /**
    * Infers the schema from `inputPaths` files.
@@ -115,17 +115,17 @@ object TextInputCSVDataSource extends CSVDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      parsedOptions: CSVOptions): Iterator[InternalRow] = {
+      schema: StructType): Iterator[InternalRow] = {
     val lines = {
       val linesReader = new HadoopFileLinesReader(file, conf)
       Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
       linesReader.map { line =>
-        new String(line.getBytes, 0, line.getLength, parsedOptions.charset)
+        new String(line.getBytes, 0, line.getLength, parser.options.charset)
       }
     }
 
-    val shouldDropHeader = parsedOptions.headerFlag && file.start == 0
-    UnivocityParser.parseIterator(lines, shouldDropHeader, parser)
+    val shouldDropHeader = parser.options.headerFlag && file.start == 0
+    UnivocityParser.parseIterator(lines, shouldDropHeader, parser, schema)
   }
 
   override def infer(
@@ -192,11 +192,12 @@ object WholeFileCSVDataSource extends CSVDataSource {
       conf: Configuration,
       file: PartitionedFile,
       parser: UnivocityParser,
-      parsedOptions: CSVOptions): Iterator[InternalRow] = {
+      schema: StructType): Iterator[InternalRow] = {
     UnivocityParser.parseStream(
       CodecStreams.createInputStreamWithCloseResource(conf, file.filePath),
-      parsedOptions.headerFlag,
-      parser)
+      parser.options.headerFlag,
+      parser,
+      schema)
   }
 
   override def infer(

http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 29c4145..eef43c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -113,8 +113,11 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister {
 
     (file: PartitionedFile) => {
       val conf = broadcastedHadoopConf.value.value
-      val parser = new UnivocityParser(dataSchema, requiredSchema, parsedOptions)
-      CSVDataSource(parsedOptions).readFile(conf, file, parser, parsedOptions)
+      val parser = new UnivocityParser(
+        StructType(dataSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
+        StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord)),
+        parsedOptions)
+      CSVDataSource(parsedOptions).readFile(conf, file, parser, requiredSchema)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 2632e87..f6c6b6f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -82,7 +82,7 @@ class CSVOptions(
 
   val delimiter = CSVUtils.toChar(
     parameters.getOrElse("sep", parameters.getOrElse("delimiter", ",")))
-  private val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
+  val parseMode = parameters.getOrElse("mode", "PERMISSIVE")
   val charset = parameters.getOrElse("encoding",
     parameters.getOrElse("charset", StandardCharsets.UTF_8.name()))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index e42ea3f..263f77e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -30,14 +30,14 @@ import com.univocity.parsers.csv.CsvParser
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils, FailureSafeParser}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 class UnivocityParser(
     schema: StructType,
     requiredSchema: StructType,
-    private val options: CSVOptions) extends Logging {
+    val options: CSVOptions) extends Logging {
   require(requiredSchema.toSet.subsetOf(schema.toSet),
     "requiredSchema should be the subset of schema.")
 
@@ -46,39 +46,26 @@ class UnivocityParser(
   // A `ValueConverter` is responsible for converting the given value to a desired type.
   private type ValueConverter = String => Any
 
-  private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord)
-  corruptFieldIndex.foreach { corrFieldIndex =>
-    require(schema(corrFieldIndex).dataType == StringType)
-    require(schema(corrFieldIndex).nullable)
-  }
-
-  private val dataSchema = StructType(schema.filter(_.name != options.columnNameOfCorruptRecord))
-
   private val tokenizer = new CsvParser(options.asParserSettings)
 
-  private var numMalformedRecords = 0
-
   private val row = new GenericInternalRow(requiredSchema.length)
 
-  // In `PERMISSIVE` parse mode, we should be able to put the raw malformed row into the field
-  // specified in `columnNameOfCorruptRecord`. The raw input is retrieved by this method.
-  private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd
+  // Retrieve the raw record string.
+  private def getCurrentInput: UTF8String = {
+    UTF8String.fromString(tokenizer.getContext.currentParsedContent().stripLineEnd)
+  }
 
-  // This parser loads an `tokenIndexArr`-th position value in input tokens,
-  // then put the value in `row(rowIndexArr)`.
+  // This parser first picks some tokens from the input tokens, according to the required schema,
+  // then parse these tokens and put the values in a row, with the order specified by the required
+  // schema.
   //
   // For example, let's say there is CSV data as below:
   //
   //   a,b,c
   //   1,2,A
   //
-  // Also, let's say `columnNameOfCorruptRecord` is set to "_unparsed", `header` is `true`
-  // by user and the user selects "c", "b", "_unparsed" and "a" fields. In this case, we need
-  // to map those values below:
-  //
-  //   required schema - ["c", "b", "_unparsed", "a"]
-  //   CSV data schema - ["a", "b", "c"]
-  //   required CSV data schema - ["c", "b", "a"]
+  // So the CSV data schema is: ["a", "b", "c"]
+  // And let's say the required schema is: ["c", "b"]
   //
   // with the input tokens,
   //
@@ -86,45 +73,12 @@ class UnivocityParser(
   //
   // Each input token is placed in each output row's position by mapping these. In this case,
   //
-  //   output row - ["A", 2, null, 1]
-  //
-  // In more details,
-  // - `valueConverters`, input tokens - CSV data schema
-  //   `valueConverters` keeps the positions of input token indices (by its index) to each
-  //   value's converter (by its value) in an order of CSV data schema. In this case,
-  //   [string->int, string->int, string->string].
-  //
-  // - `tokenIndexArr`, input tokens - required CSV data schema
-  //   `tokenIndexArr` keeps the positions of input token indices (by its index) to reordered
-  //   fields given the required CSV data schema (by its value). In this case, [2, 1, 0].
-  //
-  // - `rowIndexArr`, input tokens - required schema
-  //   `rowIndexArr` keeps the positions of input token indices (by its index) to reordered
-  //   field indices given the required schema (by its value). In this case, [0, 1, 3].
+  //   output row - ["A", 2]
   private val valueConverters: Array[ValueConverter] =
-    dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
-
-  // Only used to create both `tokenIndexArr` and `rowIndexArr`. This variable means
-  // the fields that we should try to convert.
-  private val reorderedFields = if (options.dropMalformed) {
-    // If `dropMalformed` is enabled, then it needs to parse all the values
-    // so that we can decide which row is malformed.
-    requiredSchema ++ schema.filterNot(requiredSchema.contains(_))
-  } else {
-    requiredSchema
-  }
+    schema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
 
   private val tokenIndexArr: Array[Int] = {
-    reorderedFields
-      .filter(_.name != options.columnNameOfCorruptRecord)
-      .map(f => dataSchema.indexOf(f)).toArray
-  }
-
-  private val rowIndexArr: Array[Int] = if (corruptFieldIndex.isDefined) {
-    val corrFieldIndex = corruptFieldIndex.get
-    reorderedFields.indices.filter(_ != corrFieldIndex).toArray
-  } else {
-    reorderedFields.indices.toArray
+    requiredSchema.map(f => schema.indexOf(f)).toArray
   }
 
   /**
@@ -205,7 +159,7 @@ class UnivocityParser(
       }
 
     case _: StringType => (d: String) =>
-      nullSafeDatum(d, name, nullable, options)(UTF8String.fromString(_))
+      nullSafeDatum(d, name, nullable, options)(UTF8String.fromString)
 
     case udt: UserDefinedType[_] => (datum: String) =>
       makeConverter(name, udt.sqlType, nullable, options)
@@ -233,81 +187,41 @@ class UnivocityParser(
    * Parses a single CSV string and turns it into either one resulting row or no row (if the
    * the record is malformed).
    */
-  def parse(input: String): Option[InternalRow] = convert(tokenizer.parseLine(input))
-
-  private def convert(tokens: Array[String]): Option[InternalRow] = {
-    convertWithParseMode(tokens) { tokens =>
-      var i: Int = 0
-      while (i < tokenIndexArr.length) {
-        // It anyway needs to try to parse since it decides if this row is malformed
-        // or not after trying to cast in `DROPMALFORMED` mode even if the casted
-        // value is not stored in the row.
-        val from = tokenIndexArr(i)
-        val to = rowIndexArr(i)
-        val value = valueConverters(from).apply(tokens(from))
-        if (i < requiredSchema.length) {
-          row(to) = value
-        }
-        i += 1
-      }
-      row
-    }
-  }
-
-  private def convertWithParseMode(
-      tokens: Array[String])(convert: Array[String] => InternalRow): Option[InternalRow] = {
-    if (options.dropMalformed && dataSchema.length != tokens.length) {
-      if (numMalformedRecords < options.maxMalformedLogPerPartition) {
-        logWarning(s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}")
-      }
-      if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) {
-        logWarning(
-          s"More than ${options.maxMalformedLogPerPartition} malformed records have been " +
-            "found on this partition. Malformed records from now on will not be logged.")
+  def parse(input: String): InternalRow = convert(tokenizer.parseLine(input))
+
+  private def convert(tokens: Array[String]): InternalRow = {
+    if (tokens.length != schema.length) {
+      // If the number of tokens doesn't match the schema, we should treat it as a malformed record.
+      // However, we still have chance to parse some of the tokens, by adding extra null tokens in
+      // the tail if the number is smaller, or by dropping extra tokens if the number is larger.
+      val checkedTokens = if (schema.length > tokens.length) {
+        tokens ++ new Array[String](schema.length - tokens.length)
+      } else {
+        tokens.take(schema.length)
       }
-      numMalformedRecords += 1
-      None
-    } else if (options.failFast && dataSchema.length != tokens.length) {
-      throw new RuntimeException(s"Malformed line in FAILFAST mode: " +
-        s"${tokens.mkString(options.delimiter.toString)}")
-    } else {
-      // If a length of parsed tokens is not equal to expected one, it makes the length the same
-      // with the expected. If the length is shorter, it adds extra tokens in the tail.
-      // If longer, it drops extra tokens.
-      //
-      // TODO: Revisit this; if a length of tokens does not match an expected length in the schema,
-      // we probably need to treat it as a malformed record.
-      // See an URL below for related discussions:
-      // https://github.com/apache/spark/pull/16928#discussion_r102657214
-      val checkedTokens = if (options.permissive && dataSchema.length != tokens.length) {
-        if (dataSchema.length > tokens.length) {
-          tokens ++ new Array[String](dataSchema.length - tokens.length)
-        } else {
-          tokens.take(dataSchema.length)
+      def getPartialResult(): Option[InternalRow] = {
+        try {
+          Some(convert(checkedTokens))
+        } catch {
+          case _: BadRecordException => None
         }
-      } else {
-        tokens
       }
-
+      throw BadRecordException(
+        () => getCurrentInput,
+        getPartialResult,
+        new RuntimeException("Malformed CSV record"))
+    } else {
       try {
-        Some(convert(checkedTokens))
+        var i = 0
+        while (i < requiredSchema.length) {
+          val from = tokenIndexArr(i)
+          row(i) = valueConverters(from).apply(tokens(from))
+          i += 1
+        }
+        row
       } catch {
-        case NonFatal(e) if options.permissive =>
-          val row = new GenericInternalRow(requiredSchema.length)
-          corruptFieldIndex.foreach(row(_) = UTF8String.fromString(getCurrentInput()))
-          Some(row)
-        case NonFatal(e) if options.dropMalformed =>
-          if (numMalformedRecords < options.maxMalformedLogPerPartition) {
-            logWarning("Parse exception. " +
-              s"Dropping malformed line: ${tokens.mkString(options.delimiter.toString)}")
-          }
-          if (numMalformedRecords == options.maxMalformedLogPerPartition - 1) {
-            logWarning(
-              s"More than ${options.maxMalformedLogPerPartition} malformed records have been " +
-                "found on this partition. Malformed records from now on will not be logged.")
-          }
-          numMalformedRecords += 1
-          None
+        case NonFatal(e) =>
+          throw BadRecordException(() => getCurrentInput, () => None, e)
       }
     }
   }
@@ -331,10 +245,16 @@ private[csv] object UnivocityParser {
   def parseStream(
       inputStream: InputStream,
       shouldDropHeader: Boolean,
-      parser: UnivocityParser): Iterator[InternalRow] = {
+      parser: UnivocityParser,
+      schema: StructType): Iterator[InternalRow] = {
     val tokenizer = parser.tokenizer
+    val safeParser = new FailureSafeParser[Array[String]](
+      input => Seq(parser.convert(input)),
+      parser.options.parseMode,
+      schema,
+      parser.options.columnNameOfCorruptRecord)
     convertStream(inputStream, shouldDropHeader, tokenizer) { tokens =>
-      parser.convert(tokens)
+      safeParser.parse(tokens)
     }.flatten
   }
 
@@ -368,7 +288,8 @@ private[csv] object UnivocityParser {
   def parseIterator(
       lines: Iterator[String],
       shouldDropHeader: Boolean,
-      parser: UnivocityParser): Iterator[InternalRow] = {
+      parser: UnivocityParser,
+      schema: StructType): Iterator[InternalRow] = {
     val options = parser.options
 
     val linesWithoutHeader = if (shouldDropHeader) {
@@ -381,6 +302,12 @@ private[csv] object UnivocityParser {
 
     val filteredLines: Iterator[String] =
       CSVUtils.filterCommentAndEmpty(linesWithoutHeader, options)
-    filteredLines.flatMap(line => parser.parse(line))
+
+    val safeParser = new FailureSafeParser[String](
+      input => Seq(parser.parse(input)),
+      parser.options.parseMode,
+      schema,
+      parser.options.columnNameOfCorruptRecord)
+    filteredLines.flatMap(safeParser.parse)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 84f0266..51e952c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.json
 
+import java.io.InputStream
+
 import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
 import com.google.common.io.ByteStreams
 import org.apache.hadoop.conf.Configuration
@@ -31,6 +33,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.catalyst.util.FailureSafeParser
 import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
@@ -49,7 +52,8 @@ abstract class JsonDataSource extends Serializable {
   def readFile(
     conf: Configuration,
     file: PartitionedFile,
-    parser: JacksonParser): Iterator[InternalRow]
+    parser: JacksonParser,
+    schema: StructType): Iterator[InternalRow]
 
   final def inferSchema(
       sparkSession: SparkSession,
@@ -127,10 +131,16 @@ object TextInputJsonDataSource extends JsonDataSource {
   override def readFile(
       conf: Configuration,
       file: PartitionedFile,
-      parser: JacksonParser): Iterator[InternalRow] = {
+      parser: JacksonParser,
+      schema: StructType): Iterator[InternalRow] = {
     val linesReader = new HadoopFileLinesReader(file, conf)
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
-    linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String))
+    val safeParser = new FailureSafeParser[Text](
+      input => parser.parse(input, CreateJacksonParser.text, textToUTF8String),
+      parser.options.parseMode,
+      schema,
+      parser.options.columnNameOfCorruptRecord)
+    linesReader.flatMap(safeParser.parse)
   }
 
   private def textToUTF8String(value: Text): UTF8String = {
@@ -180,7 +190,8 @@ object WholeFileJsonDataSource extends JsonDataSource {
   override def readFile(
       conf: Configuration,
       file: PartitionedFile,
-      parser: JacksonParser): Iterator[InternalRow] = {
+      parser: JacksonParser,
+      schema: StructType): Iterator[InternalRow] = {
     def partitionedFileString(ignored: Any): UTF8String = {
       Utils.tryWithResource {
         CodecStreams.createInputStreamWithCloseResource(conf, file.filePath)
@@ -189,9 +200,13 @@ object WholeFileJsonDataSource extends JsonDataSource {
       }
     }
 
-    parser.parse(
-      CodecStreams.createInputStreamWithCloseResource(conf, file.filePath),
-      CreateJacksonParser.inputStream,
-      partitionedFileString).toIterator
+    val safeParser = new FailureSafeParser[InputStream](
+      input => parser.parse(input, CreateJacksonParser.inputStream, partitionedFileString),
+      parser.options.parseMode,
+      schema,
+      parser.options.columnNameOfCorruptRecord)
+
+    safeParser.parse(
+      CodecStreams.createInputStreamWithCloseResource(conf, file.filePath))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index a9dd91e..53d62d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -102,6 +102,8 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
+    val actualSchema =
+      StructType(requiredSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
     // Check a field requirement for corrupt records here to throw an exception in a driver side
     dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
       val f = dataSchema(corruptFieldIndex)
@@ -112,11 +114,12 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
     }
 
     (file: PartitionedFile) => {
-      val parser = new JacksonParser(requiredSchema, parsedOptions)
+      val parser = new JacksonParser(actualSchema, parsedOptions)
       JsonDataSource(parsedOptions).readFile(
         broadcastedHadoopConf.value.value,
         file,
-        parser)
+        parser,
+        requiredSchema)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 95dfdf5..598babf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -293,7 +293,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
           .load(testFile(carsFile)).collect()
       }
 
-      assert(exception.getMessage.contains("Malformed line in FAILFAST mode: 2015,Chevy,Volt"))
+      assert(exception.getMessage.contains("Malformed CSV record"))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/68d65fae/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 9b0efcb..56fcf773 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -1043,7 +1043,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
         .json(corruptRecords)
         .collect()
     }
-    assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode: {"))
+    assert(exceptionOne.getMessage.contains("JsonParseException"))
 
     val exceptionTwo = intercept[SparkException] {
       spark.read
@@ -1052,7 +1052,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
         .json(corruptRecords)
         .collect()
     }
-    assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode: {"))
+    assert(exceptionTwo.getMessage.contains("JsonParseException"))
   }
 
   test("Corrupt records: DROPMALFORMED mode") {
@@ -1929,7 +1929,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
           .json(path)
           .collect()
       }
-      assert(exceptionOne.getMessage.contains("Malformed line in FAILFAST mode"))
+      assert(exceptionOne.getMessage.contains("Failed to parse a value"))
 
       val exceptionTwo = intercept[SparkException] {
         spark.read
@@ -1939,7 +1939,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
           .json(path)
           .collect()
       }
-      assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode"))
+      assert(exceptionTwo.getMessage.contains("Failed to parse a value"))
     }
   }
 

