You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/03/15 02:19:28 UTC

spark git commit: [SPARK-19918][SQL] Use TextFileFormat in implementation of TextInputJsonDataSource

Repository: spark
Updated Branches:
  refs/heads/master dacc382f0 -> 8fb2a02e2


[SPARK-19918][SQL] Use TextFileFormat in implementation of TextInputJsonDataSource

## What changes were proposed in this pull request?

This PR proposes using the text datasource for JSON schema inference.

This is essentially the same approach as in https://github.com/apache/spark/pull/15813. Using a Dataset for the initial load when inferring the schema has several advantages; please refer to SPARK-18362.

It seems the JSON case was supposed to be fixed together with CSV, but it was split out, according to https://github.com/apache/spark/pull/15813:

> A similar problem also affects the JSON file format and this patch originally fixed that as well, but I've decided to split that change into a separate patch so as not to conflict with changes in another JSON PR.

Also, the current approach appears to affect some functionality because it does not use `FileScanRDD`. This problem is described in SPARK-19885 (which reported it for the CSV case).

## How was this patch tested?

Existing tests should cover this. It was also tested manually by running `spark.read.json(path)` and checking the UI.

Author: hyukjinkwon <gu...@gmail.com>

Closes #17255 from HyukjinKwon/json-filescanrdd.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8fb2a02e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8fb2a02e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8fb2a02e

Branch: refs/heads/master
Commit: 8fb2a02e2ce6832e3d9338a7d0148dfac9fa24c2
Parents: dacc382
Author: hyukjinkwon <gu...@gmail.com>
Authored: Wed Mar 15 10:19:19 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Wed Mar 15 10:19:19 2017 +0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/DataFrameReader.scala  |   9 +-
 .../datasources/json/JsonDataSource.scala       | 145 +++++++++----------
 .../datasources/json/JsonFileFormat.scala       |   2 +-
 .../datasources/json/JsonInferSchema.scala      |   9 +-
 .../execution/datasources/json/JsonUtils.scala  |  51 +++++++
 5 files changed, 122 insertions(+), 94 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8fb2a02e/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index f1bce1a..309654c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.csv._
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.jdbc._
-import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
+import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -376,17 +376,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       extraOptions.toMap,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
-    val createParser = CreateJacksonParser.string _
 
     val schema = userSpecifiedSchema.getOrElse {
-      JsonInferSchema.infer(
-        jsonDataset.rdd,
-        parsedOptions,
-        createParser)
+      TextInputJsonDataSource.inferFromDataset(jsonDataset, parsedOptions)
     }
 
     verifyColumnNameOfCorruptRecord(schema, parsedOptions.columnNameOfCorruptRecord)
 
+    val createParser = CreateJacksonParser.string _
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
       val parser = new JacksonParser(schema, parsedOptions)
       iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))

http://git-wip-us.apache.org/repos/asf/spark/blob/8fb2a02e/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 18843bf..84f0266 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -17,32 +17,30 @@
 
 package org.apache.spark.sql.execution.datasources.json
 
-import scala.reflect.ClassTag
-
 import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
 import com.google.common.io.ByteStreams
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.FileStatus
-import org.apache.hadoop.io.{LongWritable, Text}
+import org.apache.hadoop.io.Text
 import org.apache.hadoop.mapreduce.Job
-import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, TextInputFormat}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
 
 import org.apache.spark.TaskContext
 import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
-import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.{AnalysisException, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
-import org.apache.spark.sql.execution.datasources.{CodecStreams, HadoopFileLinesReader, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.{CodecStreams, DataSource, HadoopFileLinesReader, PartitionedFile}
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
 /**
  * Common functions for parsing JSON files
- * @tparam T A datatype containing the unparsed JSON, such as [[Text]] or [[String]]
  */
-abstract class JsonDataSource[T] extends Serializable {
+abstract class JsonDataSource extends Serializable {
   def isSplitable: Boolean
 
   /**
@@ -53,28 +51,12 @@ abstract class JsonDataSource[T] extends Serializable {
     file: PartitionedFile,
     parser: JacksonParser): Iterator[InternalRow]
 
-  /**
-   * Create an [[RDD]] that handles the preliminary parsing of [[T]] records
-   */
-  protected def createBaseRdd(
-    sparkSession: SparkSession,
-    inputPaths: Seq[FileStatus]): RDD[T]
-
-  /**
-   * A generic wrapper to invoke the correct [[JsonFactory]] method to allocate a [[JsonParser]]
-   * for an instance of [[T]]
-   */
-  def createParser(jsonFactory: JsonFactory, value: T): JsonParser
-
-  final def infer(
+  final def inferSchema(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus],
       parsedOptions: JSONOptions): Option[StructType] = {
     if (inputPaths.nonEmpty) {
-      val jsonSchema = JsonInferSchema.infer(
-        createBaseRdd(sparkSession, inputPaths),
-        parsedOptions,
-        createParser)
+      val jsonSchema = infer(sparkSession, inputPaths, parsedOptions)
       checkConstraints(jsonSchema)
       Some(jsonSchema)
     } else {
@@ -82,6 +64,11 @@ abstract class JsonDataSource[T] extends Serializable {
     }
   }
 
+  protected def infer(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus],
+      parsedOptions: JSONOptions): StructType
+
   /** Constraints to be imposed on schema to be stored. */
   private def checkConstraints(schema: StructType): Unit = {
     if (schema.fieldNames.length != schema.fieldNames.distinct.length) {
@@ -95,53 +82,46 @@ abstract class JsonDataSource[T] extends Serializable {
 }
 
 object JsonDataSource {
-  def apply(options: JSONOptions): JsonDataSource[_] = {
+  def apply(options: JSONOptions): JsonDataSource = {
     if (options.wholeFile) {
       WholeFileJsonDataSource
     } else {
       TextInputJsonDataSource
     }
   }
-
-  /**
-   * Create a new [[RDD]] via the supplied callback if there is at least one file to process,
-   * otherwise an [[org.apache.spark.rdd.EmptyRDD]] will be returned.
-   */
-  def createBaseRdd[T : ClassTag](
-      sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus])(
-      fn: (Configuration, String) => RDD[T]): RDD[T] = {
-    val paths = inputPaths.map(_.getPath)
-
-    if (paths.nonEmpty) {
-      val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
-      FileInputFormat.setInputPaths(job, paths: _*)
-      fn(job.getConfiguration, paths.mkString(","))
-    } else {
-      sparkSession.sparkContext.emptyRDD[T]
-    }
-  }
 }
 
-object TextInputJsonDataSource extends JsonDataSource[Text] {
+object TextInputJsonDataSource extends JsonDataSource {
   override val isSplitable: Boolean = {
     // splittable if the underlying source is
     true
   }
 
-  override protected def createBaseRdd(
+  override def infer(
       sparkSession: SparkSession,
-      inputPaths: Seq[FileStatus]): RDD[Text] = {
-    JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
-      case (conf, name) =>
-        sparkSession.sparkContext.newAPIHadoopRDD(
-          conf,
-          classOf[TextInputFormat],
-          classOf[LongWritable],
-          classOf[Text])
-          .setName(s"JsonLines: $name")
-          .values // get the text column
-    }
+      inputPaths: Seq[FileStatus],
+      parsedOptions: JSONOptions): StructType = {
+    val json: Dataset[String] = createBaseDataset(sparkSession, inputPaths)
+    inferFromDataset(json, parsedOptions)
+  }
+
+  def inferFromDataset(json: Dataset[String], parsedOptions: JSONOptions): StructType = {
+    val sampled: Dataset[String] = JsonUtils.sample(json, parsedOptions)
+    val rdd: RDD[UTF8String] = sampled.queryExecution.toRdd.map(_.getUTF8String(0))
+    JsonInferSchema.infer(rdd, parsedOptions, CreateJacksonParser.utf8String)
+  }
+
+  private def createBaseDataset(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus]): Dataset[String] = {
+    val paths = inputPaths.map(_.getPath.toString)
+    sparkSession.baseRelationToDataFrame(
+      DataSource.apply(
+        sparkSession,
+        paths = paths,
+        className = classOf[TextFileFormat].getName
+      ).resolveRelation(checkFilesExist = false))
+      .select("value").as(Encoders.STRING)
   }
 
   override def readFile(
@@ -150,41 +130,48 @@ object TextInputJsonDataSource extends JsonDataSource[Text] {
       parser: JacksonParser): Iterator[InternalRow] = {
     val linesReader = new HadoopFileLinesReader(file, conf)
     Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close()))
-    linesReader.flatMap(parser.parse(_, createParser, textToUTF8String))
+    linesReader.flatMap(parser.parse(_, CreateJacksonParser.text, textToUTF8String))
   }
 
   private def textToUTF8String(value: Text): UTF8String = {
     UTF8String.fromBytes(value.getBytes, 0, value.getLength)
   }
-
-  override def createParser(jsonFactory: JsonFactory, value: Text): JsonParser = {
-    CreateJacksonParser.text(jsonFactory, value)
-  }
 }
 
-object WholeFileJsonDataSource extends JsonDataSource[PortableDataStream] {
+object WholeFileJsonDataSource extends JsonDataSource {
   override val isSplitable: Boolean = {
     false
   }
 
-  override protected def createBaseRdd(
+  override def infer(
+      sparkSession: SparkSession,
+      inputPaths: Seq[FileStatus],
+      parsedOptions: JSONOptions): StructType = {
+    val json: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths)
+    val sampled: RDD[PortableDataStream] = JsonUtils.sample(json, parsedOptions)
+    JsonInferSchema.infer(sampled, parsedOptions, createParser)
+  }
+
+  private def createBaseRdd(
       sparkSession: SparkSession,
       inputPaths: Seq[FileStatus]): RDD[PortableDataStream] = {
-    JsonDataSource.createBaseRdd(sparkSession, inputPaths) {
-      case (conf, name) =>
-        new BinaryFileRDD(
-          sparkSession.sparkContext,
-          classOf[StreamInputFormat],
-          classOf[String],
-          classOf[PortableDataStream],
-          conf,
-          sparkSession.sparkContext.defaultMinPartitions)
-          .setName(s"JsonFile: $name")
-          .values
-    }
+    val paths = inputPaths.map(_.getPath)
+    val job = Job.getInstance(sparkSession.sessionState.newHadoopConf())
+    val conf = job.getConfiguration
+    val name = paths.mkString(",")
+    FileInputFormat.setInputPaths(job, paths: _*)
+    new BinaryFileRDD(
+      sparkSession.sparkContext,
+      classOf[StreamInputFormat],
+      classOf[String],
+      classOf[PortableDataStream],
+      conf,
+      sparkSession.sparkContext.defaultMinPartitions)
+      .setName(s"JsonFile: $name")
+      .values
   }
 
-  override def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
+  private def createParser(jsonFactory: JsonFactory, record: PortableDataStream): JsonParser = {
     CreateJacksonParser.inputStream(
       jsonFactory,
       CodecStreams.createInputStreamWithCloseResource(record.getConfiguration, record.getPath()))

http://git-wip-us.apache.org/repos/asf/spark/blob/8fb2a02e/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 902fee5a..a9dd91e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -54,7 +54,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       options,
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
-    JsonDataSource(parsedOptions).infer(
+    JsonDataSource(parsedOptions).inferSchema(
       sparkSession, files, parsedOptions)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/8fb2a02e/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
index ab09358..7475f8e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
@@ -40,18 +40,11 @@ private[sql] object JsonInferSchema {
       json: RDD[T],
       configOptions: JSONOptions,
       createParser: (JsonFactory, T) => JsonParser): StructType = {
-    require(configOptions.samplingRatio > 0,
-      s"samplingRatio (${configOptions.samplingRatio}) should be greater than 0")
     val shouldHandleCorruptRecord = configOptions.permissive
     val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
-    val schemaData = if (configOptions.samplingRatio > 0.99) {
-      json
-    } else {
-      json.sample(withReplacement = false, configOptions.samplingRatio, 1)
-    }
 
     // perform schema inference on each row and merge afterwards
-    val rootType = schemaData.mapPartitions { iter =>
+    val rootType = json.mapPartitions { iter =>
       val factory = new JsonFactory()
       configOptions.setJacksonOptions(factory)
       iter.flatMap { row =>

http://git-wip-us.apache.org/repos/asf/spark/blob/8fb2a02e/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala
new file mode 100644
index 0000000..d511594
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonUtils.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.json
+
+import org.apache.spark.input.PortableDataStream
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.catalyst.json.JSONOptions
+
+object JsonUtils {
+  /**
+   * Sample JSON dataset as configured by `samplingRatio`.
+   */
+  def sample(json: Dataset[String], options: JSONOptions): Dataset[String] = {
+    require(options.samplingRatio > 0,
+      s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+    if (options.samplingRatio > 0.99) {
+      json
+    } else {
+      json.sample(withReplacement = false, options.samplingRatio, 1)
+    }
+  }
+
+  /**
+   * Sample JSON RDD as configured by `samplingRatio`.
+   */
+  def sample(json: RDD[PortableDataStream], options: JSONOptions): RDD[PortableDataStream] = {
+    require(options.samplingRatio > 0,
+      s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+    if (options.samplingRatio > 0.99) {
+      json
+    } else {
+      json.sample(withReplacement = false, options.samplingRatio, 1)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org