You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2016/02/10 03:50:10 UTC

spark git commit: [SPARK-13149][SQL] Add FileStreamSource

Repository: spark
Updated Branches:
  refs/heads/master 6f710f9fd -> b385ce388


[SPARK-13149][SQL] Add FileStreamSource

`FileStreamSource` is an implementation of `org.apache.spark.sql.execution.streaming.Source`. It takes advantage of the existing `HadoopFsRelationProvider` to support various file formats. It remembers the files in each batch and stores them in metadata files so that they can be recovered on restart. The metadata files are stored in the file system. A follow-up PR will clean up the metadata files periodically.

This is based on the initial work from marmbrus.

Author: Shixiong Zhu <sh...@databricks.com>

Closes #11034 from zsxwing/stream-df-file-source.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b385ce38
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b385ce38
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b385ce38

Branch: refs/heads/master
Commit: b385ce38825de4b1420c5a0e8191e91fc8afecf5
Parents: 6f710f9
Author: Shixiong Zhu <sh...@databricks.com>
Authored: Tue Feb 9 18:50:06 2016 -0800
Committer: Tathagata Das <ta...@gmail.com>
Committed: Tue Feb 9 18:50:06 2016 -0800

----------------------------------------------------------------------
 .../datasources/ResolvedDataSource.scala        |   2 +-
 .../execution/streaming/FileStreamSource.scala  | 240 ++++++++++
 .../apache/spark/sql/sources/interfaces.scala   |  33 +-
 .../scala/org/apache/spark/sql/StreamTest.scala |   2 +
 .../streaming/DataFrameReaderWriterSuite.scala  |   5 +-
 .../sql/streaming/FileStreamSourceSuite.scala   | 435 +++++++++++++++++++
 6 files changed, 710 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b385ce38/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
index 7702f53..cefa8be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala
@@ -104,7 +104,7 @@ object ResolvedDataSource extends Logging {
           s"Data source $providerName does not support streamed reading")
     }
 
-    provider.createSource(sqlContext, options, userSpecifiedSchema)
+    provider.createSource(sqlContext, userSpecifiedSchema, providerName, options)
   }
 
   def createSink(

http://git-wip-us.apache.org/repos/asf/spark/blob/b385ce38/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
new file mode 100644
index 0000000..14ba9f6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io._
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.io.Codec
+
+import com.google.common.base.Charsets.UTF_8
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.util.collection.OpenHashSet
+
+/**
+ * A streaming [[Source]] that emits the files appearing in the directory `path` as new batches.
+ *
+ * Files are read through `dataFrameBuilder`, so any format the builder supports can be streamed.
+ * Every time new files are found they are assigned to a new batch whose file list is written to
+ * `metadataPath`, so that the same batches can be recovered after a restart.
+ *
+ * TODO Clean up the metadata files periodically
+ *
+ * @param sqlContext provides the Hadoop configuration used to access the file systems
+ * @param metadataPath directory holding one metadata file per batch, named after the batch id
+ * @param path the directory to monitor for new files; names starting with "_" are ignored
+ * @param dataSchema user-specified schema; if absent, it is inferred from the files already in
+ *                   `path`, or defaults to a single string column "value" for the "text" provider
+ * @param providerName name of the data source provider (e.g. "text")
+ * @param dataFrameBuilder builds a DataFrame that reads the given file paths
+ */
+class FileStreamSource(
+    sqlContext: SQLContext,
+    metadataPath: String,
+    path: String,
+    dataSchema: Option[StructType],
+    providerName: String,
+    dataFrameBuilder: Array[String] => DataFrame) extends Source with Logging {
+
+  private val hadoopConf = sqlContext.sparkContext.hadoopConfiguration
+
+  // Resolve each file system from its own path instead of using the default file system, so that
+  // paths whose scheme differs from `fs.defaultFS` (e.g. "s3n://..." or "file:///...") work.
+  private val metadataDir = new Path(metadataPath)
+  private val metadataFs = metadataDir.getFileSystem(hadoopConf)
+  private val dataDir = new Path(path)
+  private val dataFs = dataDir.getFileSystem(hadoopConf)
+
+  /** Id of the latest batch persisted in `metadataPath`, or -1 if there is none yet. */
+  private var maxBatchId = -1
+
+  /** Every file that has already been assigned to a persisted batch. */
+  private val seenFiles = new OpenHashSet[String]
+
+  /** Map of batch id to files. This map is also stored in `metadataPath`. */
+  private val batchToMetadata = new HashMap[Long, Seq[String]]
+
+  {
+    // Recover `batchToMetadata`, `seenFiles` and `maxBatchId` from existing metadata files.
+    // Batch files are named after their numeric batch id; anything else is ignored.
+    val (batchFiles, otherFiles) =
+      fetchAllBatchFiles().partition(status => status.getPath.getName.forall(_.isDigit))
+    otherFiles.foreach { f =>
+      logWarning(s"Ignoring unexpected file in the metadata directory: ${f.getPath}")
+    }
+    if (batchFiles.nonEmpty) {
+      val existingBatchIds = batchFiles.map(_.getPath.getName.toInt)
+      maxBatchId = existingBatchIds.max
+      existingBatchIds.sorted.foreach { batchId =>
+        val files = readBatch(batchId)
+        if (files.isEmpty) {
+          // Only the latest metadata file may be corrupted (e.g. a previous run died while
+          // writing it). Drop that batch; its files are not marked as seen, so they will be
+          // picked up again by the next batch.
+          if (batchId != maxBatchId) {
+            throw new IllegalStateException(
+              s"Invalid metadata files: batch $batchId is corrupted but is not the latest batch " +
+                s"($maxBatchId)")
+          }
+          maxBatchId -= 1
+        } else {
+          batchToMetadata(batchId) = files
+          files.foreach(seenFiles.add)
+        }
+      }
+    }
+  }
+
+  /**
+   * Returns the schema of the data from this source: the user-specified schema if any, otherwise
+   * the schema inferred from the files already in `path`. If there are no such files, only the
+   * "text" provider gets a default schema (a single string column "value").
+   *
+   * @throws IllegalArgumentException if no schema was given, `path` has no files and the provider
+   *                                  is not "text"
+   */
+  override lazy val schema: StructType = {
+    dataSchema.getOrElse {
+      val filesPresent = fetchAllFiles()
+      if (filesPresent.isEmpty) {
+        if (providerName == "text") {
+          // Add a default schema for "text"
+          new StructType().add("value", StringType)
+        } else {
+          throw new IllegalArgumentException("No schema specified")
+        }
+      } else {
+        // There are some existing files. Use them to infer the schema.
+        dataFrameBuilder(filesPresent.toArray).schema
+      }
+    }
+  }
+
+  /**
+   * Scans `path` for files that have not been seen before, persists them as a new batch if there
+   * are any, and returns the offset of the latest batch.
+   *
+   * `synchronized` on this method is only needed to avoid race conditions in tests (see
+   * `withBatchingLocked`). In normal usage the lock is uncontended, so its cost is negligible.
+   */
+  private def fetchMaxOffset(): LongOffset = synchronized {
+    val (oldFiles, newFiles) = fetchAllFiles().partition(file => seenFiles.contains(file))
+    oldFiles.foreach(file => logDebug(s"old file: $file"))
+    newFiles.foreach(file => logDebug(s"new file: $file"))
+
+    if (newFiles.nonEmpty) {
+      // Persist the batch before updating the in-memory state. If the write fails, neither
+      // `maxBatchId` nor `seenFiles` changes, so the same files are retried on the next call
+      // instead of being skipped.
+      val newBatchId = maxBatchId + 1
+      writeBatch(newBatchId, newFiles)
+      maxBatchId = newBatchId
+      newFiles.foreach(seenFiles.add)
+    }
+
+    new LongOffset(maxBatchId)
+  }
+
+  /**
+   * For test only. Run `func` with the internal lock to make sure when `func` is running,
+   * the current offset won't be changed and no new batch will be emitted.
+   */
+  def withBatchingLocked[T](func: => T): T = synchronized {
+    func
+  }
+
+  /** Return the latest offset in the source */
+  def currentOffset: LongOffset = synchronized {
+    new LongOffset(maxBatchId)
+  }
+
+  /**
+   * Returns the next batch of data that is available after `start`, if any is available.
+   * The returned batch covers every persisted batch after `start` up to the latest one.
+   */
+  override def getNextBatch(start: Option[Offset]): Option[Batch] = {
+    val startId = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L)
+    val end = fetchMaxOffset()
+    val endId = end.offset
+
+    if (startId + 1 <= endId) {
+      val files = (startId + 1 to endId).filter(_ >= 0).flatMap { batchId =>
+        batchToMetadata.getOrElse(batchId, Nil)
+      }.toArray
+      logDebug(s"Return files from batches ${startId + 1}:$endId")
+      logDebug(s"Streaming ${files.mkString(", ")}")
+      Some(new Batch(end, dataFrameBuilder(files)))
+    } else {
+      None
+    }
+  }
+
+  /** Lists the metadata directory, creating it when it does not exist yet. */
+  private def fetchAllBatchFiles(): Seq[FileStatus] = {
+    try metadataFs.listStatus(metadataDir) catch {
+      case _: java.io.FileNotFoundException =>
+        metadataFs.mkdirs(metadataDir)
+        Seq.empty
+    }
+  }
+
+  /** Lists the URIs of the entries in `path`, skipping names that start with "_". */
+  private def fetchAllFiles(): Seq[String] = {
+    dataFs.listStatus(dataDir)
+      .filterNot(_.getPath.getName.startsWith("_"))
+      .map(_.getPath.toUri.toString)
+  }
+
+  /**
+   * Write the metadata of a batch to disk and record it in `batchToMetadata`. The file format is
+   * as follows:
+   *
+   * {{{
+   *   <FileStreamSource.VERSION>
+   *   START
+   *   -/a/b/c
+   *   -/d/e/f
+   *   ...
+   *   END
+   * }}}
+   *
+   * Note: <FileStreamSource.VERSION> means the value of `FileStreamSource.VERSION`. Every file
+   * path starts with "-" so that we can know if a line is a file path easily.
+   */
+  private def writeBatch(id: Int, files: Seq[String]): Unit = {
+    assert(files.nonEmpty, "create a new batch without any file")
+    val output = metadataFs.create(new Path(metadataDir, id.toString), true)
+    val writer = new PrintWriter(new OutputStreamWriter(output, UTF_8))
+    try {
+      // scalastyle:off println
+      writer.println(FileStreamSource.VERSION)
+      writer.println(FileStreamSource.START_TAG)
+      files.foreach(file => writer.println(FileStreamSource.PATH_PREFIX + file))
+      writer.println(FileStreamSource.END_TAG)
+      // scalastyle:on println
+    } finally {
+      writer.close()
+    }
+    batchToMetadata(id) = files
+  }
+
+  /**
+   * Read the file names of the specified batch id from the metadata file. Returns an empty
+   * `Seq` if the file is corrupted.
+   */
+  private def readBatch(id: Int): Seq[String] = {
+    val input = metadataFs.open(new Path(metadataDir, id.toString))
+    try {
+      FileStreamSource.readBatch(input)
+    } finally {
+      input.close()
+    }
+  }
+}
+
+object FileStreamSource {
+
+  private val START_TAG = "START"
+  private val END_TAG = "END"
+  private val PATH_PREFIX = "-"
+  val VERSION = "FILESTREAM_V1"
+
+  /**
+   * Parse the content of a batch metadata file (the format written by `FileStreamSource`'s
+   * `writeBatch`) and return the listed file paths, in order, with the "-" prefix removed.
+   *
+   * If the metadata file is corrupted, an empty `Seq` is returned. Corruption covers a wrong
+   * version line, a missing START or END tag, no file path at all, or a line between the tags
+   * that does not start with "-". `input` is read to the end but not closed; the caller owns it.
+   */
+  def readBatch(input: InputStream): Seq[String] = {
+    val lines = scala.io.Source.fromInputStream(input)(Codec.UTF8).getLines().toArray
+    // version + start tag + at least one file path + end tag
+    val hasValidFrame =
+      lines.length >= 4 &&
+        lines.head == VERSION &&
+        lines(1) == START_TAG &&
+        lines.last == END_TAG
+    if (!hasValidFrame) {
+      Nil
+    } else {
+      val pathLines = lines.slice(2, lines.length - 1)
+      if (pathLines.forall(_.startsWith(PATH_PREFIX))) {
+        pathLines.map(_.stripPrefix(PATH_PREFIX)).toSeq // Drop character "-"
+      } else {
+        Nil
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b385ce38/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 737be7d..428a313 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
 import org.apache.spark.sql.execution.{FileRelation, RDDConversions}
 import org.apache.spark.sql.execution.datasources._
-import org.apache.spark.sql.execution.streaming.{Sink, Source}
+import org.apache.spark.sql.execution.streaming.{FileStreamSource, Sink, Source}
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
 import org.apache.spark.util.collection.BitSet
@@ -131,8 +131,9 @@ trait SchemaRelationProvider {
 trait StreamSourceProvider {
+  /**
+   * Returns a streaming [[Source]] for this data source.
+   *
+   * @param sqlContext the SQLContext the source is created in
+   * @param schema the user-specified schema, if any
+   * @param providerName the name of the data source provider, as passed by the caller
+   *                     (e.g. "text" or "parquet")
+   * @param parameters the user-supplied options (for file sources this includes "path")
+   */
   def createSource(
       sqlContext: SQLContext,
-      parameters: Map[String, String],
-      schema: Option[StructType]): Source
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source
 }
 
 /**
@@ -169,7 +170,7 @@ trait StreamSinkProvider {
  * @since 1.4.0
  */
 @Experimental
-trait HadoopFsRelationProvider {
+trait HadoopFsRelationProvider extends StreamSourceProvider {
   /**
    * Returns a new base relation with the given parameters, a user defined schema, and a list of
    * partition columns. Note: the parameters' keywords are case insensitive and this insensitivity
@@ -196,6 +197,30 @@ trait HadoopFsRelationProvider {
     }
     createRelation(sqlContext, paths, dataSchema, partitionColumns, parameters)
   }
+
+  override def createSource(
+      sqlContext: SQLContext,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
+    // The monitored directory comes from the mandatory "path" option. Batch metadata goes to the
+    // "metadataPath" option, or to a "_metadata" sub-directory of "path" when it is not given.
+    val path = parameters.getOrElse("path",
+      throw new IllegalArgumentException("'path' is not specified"))
+    val metadataPath = parameters.get("metadataPath").getOrElse(path + "/_metadata")
+
+    // Each batch of files is read through this provider's regular (non-streaming) relation.
+    // NOTE(review): when `schema` is None, every batch's relation is created without a schema and
+    // presumably infers its own, which may differ from the schema the source reports - confirm.
+    val buildDataFrame = (files: Array[String]) => {
+      val relation = createRelation(
+        sqlContext, files, schema, partitionColumns = None, bucketSpec = None, parameters)
+      DataFrame(sqlContext, LogicalRelation(relation))
+    }
+
+    new FileStreamSource(
+      sqlContext, metadataPath, path, schema, providerName, buildDataFrame)
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/b385ce38/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index f45abbf..7e388ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -59,6 +59,8 @@ trait StreamTest extends QueryTest with Timeouts {
 
   implicit class RichSource(s: Source) {
+    /** Returns an untyped streaming DataFrame backed by this source. */
     def toDF(): DataFrame = new DataFrame(sqlContext, StreamingRelation(s))
+
+    /** Returns a streaming Dataset backed by this source, typed via the implicit `Encoder[A]`. */
+    def toDS[A: Encoder](): Dataset[A] = new Dataset(sqlContext, StreamingRelation(s))
   }
 
   /** How long to wait for an active stream to catch up when checking a result. */

http://git-wip-us.apache.org/repos/asf/spark/blob/b385ce38/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
index 36212e4..b762f9b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala
@@ -33,8 +33,9 @@ object LastOptions {
 class DefaultSource extends StreamSourceProvider with StreamSinkProvider {
   override def createSource(
       sqlContext: SQLContext,
-      parameters: Map[String, String],
-      schema: Option[StructType]): Source = {
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
     LastOptions.parameters = parameters
     LastOptions.schema = schema
     new Source {

http://git-wip-us.apache.org/repos/asf/spark/blob/b385ce38/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
new file mode 100644
index 0000000..7a4ee0e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -0,0 +1,435 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.io.{ByteArrayInputStream, File, FileNotFoundException, InputStream}
+
+import com.google.common.base.Charsets.UTF_8
+
+import org.apache.spark.sql.StreamTest
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.FileStreamSource._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.util.Utils
+
+class FileStreamSourceTest extends StreamTest with SharedSQLContext {
+
+  import testImplicits._
+
+  /**
+   * Writes `content` into a new text file under `tmp` and then renames it into `src`, while
+   * holding the source's batching lock. Returns the source's current offset plus one, i.e. the
+   * offset of the batch expected to contain the new file.
+   */
+  case class AddTextFileData(source: FileStreamSource, content: String, src: File, tmp: File)
+    extends AddData {
+
+    override def addData(): Offset = {
+      source.withBatchingLocked {
+        val file = Utils.tempFileWith(new File(tmp, "text"))
+        val target = new File(src, file.getName)
+        // `File.renameTo` reports failure only through its return value; fail fast instead of
+        // letting the stream test wait for data that never shows up.
+        val renamed = stringToFile(file, content).renameTo(target)
+        require(renamed, s"Failed to rename $file to $target")
+        source.currentOffset
+      } + 1
+    }
+  }
+
+  /**
+   * Writes `content` as a parquet directory under `tmp` and then renames it into `src`, while
+   * holding the source's batching lock. Returns the source's current offset plus one.
+   */
+  case class AddParquetFileData(
+      source: FileStreamSource,
+      content: Seq[String],
+      src: File,
+      tmp: File) extends AddData {
+
+    override def addData(): Offset = {
+      source.withBatchingLocked {
+        val file = Utils.tempFileWith(new File(tmp, "parquet"))
+        content.toDS().toDF().write.parquet(file.getCanonicalPath)
+        val target = new File(src, file.getName)
+        // Same as above: a failed rename must not go unnoticed.
+        val renamed = file.renameTo(target)
+        require(renamed, s"Failed to rename $file to $target")
+        source.currentOffset
+      } + 1
+    }
+  }
+
+  /** Use `format`, `path` and the optional `schema` to create FileStreamSource via DataFrameReader */
+  def createFileStreamSource(
+      format: String,
+      path: String,
+      schema: Option[StructType] = None): FileStreamSource = {
+    val reader = sqlContext.read.format(format)
+    schema.foreach(reader.schema)
+    reader.stream(path)
+      .queryExecution.analyzed
+      .collect { case StreamingRelation(s: FileStreamSource, _) => s }
+      .head
+  }
+
+  /** A schema with a single string column named "value". */
+  val valueSchema = new StructType().add("value", StringType)
+}
+
+class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext {
+
+  import testImplicits._
+
+  /**
+   * Creates a FileStreamSource via `DataFrameReader.stream`, setting the format, path and schema
+   * only when they are defined, and returns the schema of the resulting source.
+   */
+  private def createFileStreamSourceAndGetSchema(
+      format: Option[String],
+      path: Option[String],
+      schema: Option[StructType] = None): StructType = {
+    val reader = sqlContext.read
+    format.foreach(reader.format)
+    schema.foreach(reader.schema)
+    val df =
+      if (path.isDefined) {
+        reader.stream(path.get)
+      } else {
+        reader.stream()
+      }
+    df.queryExecution.analyzed
+      .collect { case StreamingRelation(s: FileStreamSource, _) => s }
+      .head
+      .schema
+  }
+
+  /**
+   * Runs `body` with freshly created source ("src") and staging ("tmp") directories, and deletes
+   * both afterwards even if `body` throws, so that failing tests do not leak temp directories.
+   */
+  private def withSrcAndTmpDirs(body: (File, File) => Unit): Unit = {
+    val src = Utils.createTempDir("streaming.src")
+    val tmp = Utils.createTempDir("streaming.tmp")
+    try {
+      body(src, tmp)
+    } finally {
+      Utils.deleteRecursively(src)
+      Utils.deleteRecursively(tmp)
+    }
+  }
+
+  test("FileStreamSource schema: no path") {
+    val e = intercept[IllegalArgumentException] {
+      createFileStreamSourceAndGetSchema(format = None, path = None, schema = None)
+    }
+    assert("'path' is not specified" === e.getMessage)
+  }
+
+  test("FileStreamSource schema: path doesn't exist") {
+    intercept[FileNotFoundException] {
+      createFileStreamSourceAndGetSchema(format = None, path = Some("/a/b/c"), schema = None)
+    }
+  }
+
+  test("FileStreamSource schema: text, no existing files, no schema") {
+    withTempDir { src =>
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("text"), path = Some(src.getCanonicalPath), schema = None)
+      assert(schema === new StructType().add("value", StringType))
+    }
+  }
+
+  test("FileStreamSource schema: text, existing files, no schema") {
+    withTempDir { src =>
+      stringToFile(new File(src, "1"), "a\nb\nc")
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("text"), path = Some(src.getCanonicalPath), schema = None)
+      assert(schema === new StructType().add("value", StringType))
+    }
+  }
+
+  test("FileStreamSource schema: text, existing files, schema") {
+    withTempDir { src =>
+      stringToFile(new File(src, "1"), "a\nb\nc")
+      val userSchema = new StructType().add("userColumn", StringType)
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("text"), path = Some(src.getCanonicalPath), schema = Some(userSchema))
+      assert(schema === userSchema)
+    }
+  }
+
+  test("FileStreamSource schema: parquet, no existing files, no schema") {
+    withTempDir { src =>
+      val e = intercept[IllegalArgumentException] {
+        createFileStreamSourceAndGetSchema(
+          format = Some("parquet"), path = Some(new File(src, "1").getCanonicalPath), schema = None)
+      }
+      assert("No schema specified" === e.getMessage)
+    }
+  }
+
+  test("FileStreamSource schema: parquet, existing files, no schema") {
+    withTempDir { src =>
+      // `as` only aliases the Dataset; the written column is still named "value".
+      Seq("a", "b", "c").toDS().as("userColumn").toDF()
+        .write.parquet(new File(src, "1").getCanonicalPath)
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("parquet"), path = Some(src.getCanonicalPath), schema = None)
+      assert(schema === new StructType().add("value", StringType))
+    }
+  }
+
+  test("FileStreamSource schema: parquet, existing files, schema") {
+    withTempDir { src =>
+      Seq("a", "b", "c").toDS().as("oldUserColumn").toDF()
+        .write.parquet(new File(src, "1").getCanonicalPath)
+      val userSchema = new StructType().add("userColumn", StringType)
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("parquet"), path = Some(src.getCanonicalPath), schema = Some(userSchema))
+      assert(schema === userSchema)
+    }
+  }
+
+  test("FileStreamSource schema: json, no existing files, no schema") {
+    withTempDir { src =>
+      val e = intercept[IllegalArgumentException] {
+        createFileStreamSourceAndGetSchema(
+          format = Some("json"), path = Some(src.getCanonicalPath), schema = None)
+      }
+      assert("No schema specified" === e.getMessage)
+    }
+  }
+
+  test("FileStreamSource schema: json, existing files, no schema") {
+    withTempDir { src =>
+      stringToFile(new File(src, "1"), "{'c': '1'}\n{'c': '2'}\n{'c': '3'}")
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("json"), path = Some(src.getCanonicalPath), schema = None)
+      assert(schema === new StructType().add("c", StringType))
+    }
+  }
+
+  test("FileStreamSource schema: json, existing files, schema") {
+    withTempDir { src =>
+      stringToFile(new File(src, "1"), "{'c': '1'}\n{'c': '2'}\n{'c': '3'}")
+      val userSchema = new StructType().add("userColumn", StringType)
+      val schema = createFileStreamSourceAndGetSchema(
+        format = Some("json"), path = Some(src.getCanonicalPath), schema = Some(userSchema))
+      assert(schema === userSchema)
+    }
+  }
+
+  test("read from text files") {
+    withSrcAndTmpDirs { (src, tmp) =>
+      val textSource = createFileStreamSource("text", src.getCanonicalPath)
+      val filtered = textSource.toDF().filter($"value" contains "keep")
+
+      testStream(filtered)(
+        AddTextFileData(textSource, "drop1\nkeep2\nkeep3", src, tmp),
+        CheckAnswer("keep2", "keep3"),
+        StopStream,
+        AddTextFileData(textSource, "drop4\nkeep5\nkeep6", src, tmp),
+        StartStream,
+        CheckAnswer("keep2", "keep3", "keep5", "keep6"),
+        AddTextFileData(textSource, "drop7\nkeep8\nkeep9", src, tmp),
+        CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9")
+      )
+    }
+  }
+
+  test("read from json files") {
+    withSrcAndTmpDirs { (src, tmp) =>
+      val jsonSource = createFileStreamSource("json", src.getCanonicalPath, Some(valueSchema))
+      val filtered = jsonSource.toDF().filter($"value" contains "keep")
+
+      testStream(filtered)(
+        AddTextFileData(
+          jsonSource,
+          "{'value': 'drop1'}\n{'value': 'keep2'}\n{'value': 'keep3'}",
+          src,
+          tmp),
+        CheckAnswer("keep2", "keep3"),
+        StopStream,
+        AddTextFileData(
+          jsonSource,
+          "{'value': 'drop4'}\n{'value': 'keep5'}\n{'value': 'keep6'}",
+          src,
+          tmp),
+        StartStream,
+        CheckAnswer("keep2", "keep3", "keep5", "keep6"),
+        AddTextFileData(
+          jsonSource,
+          "{'value': 'drop7'}\n{'value': 'keep8'}\n{'value': 'keep9'}",
+          src,
+          tmp),
+        CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9")
+      )
+    }
+  }
+
+  test("read from json files with inferring schema") {
+    withSrcAndTmpDirs { (src, tmp) =>
+      // Add a file so that we can infer its schema
+      stringToFile(new File(src, "existing"), "{'c': 'drop1'}\n{'c': 'keep2'}\n{'c': 'keep3'}")
+
+      val jsonSource = createFileStreamSource("json", src.getCanonicalPath)
+
+      // FileStreamSource should infer the column "c"
+      val filtered = jsonSource.toDF().filter($"c" contains "keep")
+
+      testStream(filtered)(
+        AddTextFileData(jsonSource, "{'c': 'drop4'}\n{'c': 'keep5'}\n{'c': 'keep6'}", src, tmp),
+        CheckAnswer("keep2", "keep3", "keep5", "keep6")
+      )
+    }
+  }
+
+  test("read from parquet files") {
+    withSrcAndTmpDirs { (src, tmp) =>
+      val fileSource = createFileStreamSource("parquet", src.getCanonicalPath, Some(valueSchema))
+      val filtered = fileSource.toDF().filter($"value" contains "keep")
+
+      testStream(filtered)(
+        AddParquetFileData(fileSource, Seq("drop1", "keep2", "keep3"), src, tmp),
+        CheckAnswer("keep2", "keep3"),
+        StopStream,
+        AddParquetFileData(fileSource, Seq("drop4", "keep5", "keep6"), src, tmp),
+        StartStream,
+        CheckAnswer("keep2", "keep3", "keep5", "keep6"),
+        AddParquetFileData(fileSource, Seq("drop7", "keep8", "keep9"), src, tmp),
+        CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9")
+      )
+    }
+  }
+
+  test("file stream source without schema") {
+    withTempDir { src =>
+      // Only "text" doesn't need a schema
+      createFileStreamSource("text", src.getCanonicalPath)
+
+      // Both "json" and "parquet" require a schema if no existing file to infer
+      intercept[IllegalArgumentException] {
+        createFileStreamSource("json", src.getCanonicalPath)
+      }
+      intercept[IllegalArgumentException] {
+        createFileStreamSource("parquet", src.getCanonicalPath)
+      }
+    }
+  }
+
+  test("fault tolerance") {
+    def assertBatch(batch1: Option[Batch], batch2: Option[Batch]): Unit = {
+      (batch1, batch2) match {
+        case (Some(b1), Some(b2)) =>
+          assert(b1.end === b2.end)
+          assert(b1.data.as[String].collect() === b2.data.as[String].collect())
+        case (None, None) =>
+        case _ => fail(s"batch ($batch1) is not equal to batch ($batch2)")
+      }
+    }
+
+    withSrcAndTmpDirs { (src, tmp) =>
+      val textSource = createFileStreamSource("text", src.getCanonicalPath)
+      val filtered = textSource.toDF().filter($"value" contains "keep")
+
+      testStream(filtered)(
+        AddTextFileData(textSource, "drop1\nkeep2\nkeep3", src, tmp),
+        CheckAnswer("keep2", "keep3"),
+        StopStream,
+        AddTextFileData(textSource, "drop4\nkeep5\nkeep6", src, tmp),
+        StartStream,
+        CheckAnswer("keep2", "keep3", "keep5", "keep6"),
+        AddTextFileData(textSource, "drop7\nkeep8\nkeep9", src, tmp),
+        CheckAnswer("keep2", "keep3", "keep5", "keep6", "keep8", "keep9")
+      )
+
+      // A second source over the same directory must recover the same batches from the metadata.
+      val textSource2 = createFileStreamSource("text", src.getCanonicalPath)
+      assert(textSource2.currentOffset === textSource.currentOffset)
+      assertBatch(textSource2.getNextBatch(None), textSource.getNextBatch(None))
+      for (f <- 0L to textSource.currentOffset.offset) {
+        val offset = LongOffset(f)
+        assertBatch(textSource2.getNextBatch(Some(offset)), textSource.getNextBatch(Some(offset)))
+      }
+    }
+  }
+
+  test("fault tolerance with corrupted metadata file") {
+    withTempDir { src =>
+      assert(new File(src, "_metadata").mkdirs())
+      stringToFile(
+        new File(src, "_metadata/0"),
+        s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n-/e/f/g\nEND\n")
+      // Batch 1 is truncated (no END tag), as if the writer died while writing it.
+      stringToFile(new File(src, "_metadata/1"), s"${FileStreamSource.VERSION}\nSTART\n-")
+
+      val textSource = createFileStreamSource("text", src.getCanonicalPath)
+      // the metadata file of batch 1 is corrupted, so currentOffset should be 0
+      assert(textSource.currentOffset === LongOffset(0))
+    }
+  }
+
+  test("fault tolerance with normal metadata file") {
+    withTempDir { src =>
+      assert(new File(src, "_metadata").mkdirs())
+      stringToFile(
+        new File(src, "_metadata/0"),
+        s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n-/e/f/g\nEND\n")
+      stringToFile(
+        new File(src, "_metadata/1"),
+        s"${FileStreamSource.VERSION}\nSTART\n-/x/y/z\nEND\n")
+
+      val textSource = createFileStreamSource("text", src.getCanonicalPath)
+      assert(textSource.currentOffset === LongOffset(1))
+    }
+  }
+
+  test("readBatch") {
+    def stringToStream(str: String): InputStream = new ByteArrayInputStream(str.getBytes(UTF_8))
+
+    // Invalid metadata
+    assert(readBatch(stringToStream("")) === Nil)
+    assert(readBatch(stringToStream(FileStreamSource.VERSION)) === Nil)
+    assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\n")) === Nil)
+    assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART")) === Nil)
+    assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-")) === Nil)
+    assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c")) === Nil)
+    assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n")) === Nil)
+    assert(readBatch(stringToStream(s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\nEN")) === Nil)
+
+    // Valid metadata
+    assert(readBatch(stringToStream(
+      s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\nEND")) === Seq("/a/b/c"))
+    assert(readBatch(stringToStream(
+      s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\nEND\n")) === Seq("/a/b/c"))
+    assert(readBatch(stringToStream(
+      s"${FileStreamSource.VERSION}\nSTART\n-/a/b/c\n-/e/f/g\nEND\n"))
+      === Seq("/a/b/c", "/e/f/g"))
+  }
+}
+
+class FileStreamSourceStressTestSuite extends FileStreamSourceTest with SharedSQLContext {
+
+  import testImplicits._
+
+  test("file source stress test") {
+    val src = Utils.createTempDir("streaming.src")
+    val tmp = Utils.createTempDir("streaming.tmp")
+    try {
+      val textSource = createFileStreamSource("text", src.getCanonicalPath)
+      // Each line is parsed as an integer and incremented by one.
+      val ds = textSource.toDS[String]().map(_.toInt + 1)
+      runStressTest(ds, data => {
+        AddTextFileData(textSource, data.mkString("\n"), src, tmp)
+      })
+    } finally {
+      // Clean up even when the stress test fails.
+      Utils.deleteRecursively(src)
+      Utils.deleteRecursively(tmp)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org