You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2018/05/04 21:14:46 UTC

spark git commit: [SPARK-24039][SS] Do continuous processing writes with multiple compute() calls

Repository: spark
Updated Branches:
  refs/heads/master d04806a23 -> af4dc5028


[SPARK-24039][SS] Do continuous processing writes with multiple compute() calls

## What changes were proposed in this pull request?

Do continuous processing writes with multiple compute() calls.

The current strategy (before this PR) is hacky; we just call next() on an iterator which has already returned hasNext = false, knowing that all the nodes we whitelist handle this properly. This will have to be changed before we can support more complex query plans. (In particular, I have a WIP https://github.com/jose-torres/spark/pull/13 which should be able to support aggregates in a single partition with minimal additional work.)

Most of the changes here are just refactoring to accommodate the new model. The behavioral changes are:

* The writer now calls prev.compute(split, context) once per epoch within the epoch loop.
* ContinuousDataSourceRDD now spawns a ContinuousQueuedDataReader which is shared across multiple calls to compute() for the same partition.

## How was this patch tested?

Existing unit tests, plus the new ContinuousQueuedDataReaderSuite added in this commit.

Author: Jose Torres <to...@gmail.com>

Closes #21200 from jose-torres/noAggr.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/af4dc502
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/af4dc502
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/af4dc502

Branch: refs/heads/master
Commit: af4dc50280ffcdeda208ef2dc5f8b843389732e5
Parents: d04806a
Author: Jose Torres <to...@gmail.com>
Authored: Fri May 4 14:14:40 2018 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Fri May 4 14:14:40 2018 -0700

----------------------------------------------------------------------
 .../datasources/v2/DataSourceV2ScanExec.scala   |   6 +-
 .../continuous/ContinuousDataSourceRDD.scala    | 114 ++++++++++
 .../ContinuousDataSourceRDDIter.scala           | 222 -------------------
 .../continuous/ContinuousQueuedDataReader.scala | 211 ++++++++++++++++++
 .../continuous/ContinuousWriteRDD.scala         |  90 ++++++++
 .../WriteToContinuousDataSourceExec.scala       |  57 +----
 .../ContinuousQueuedDataReaderSuite.scala       | 167 ++++++++++++++
 7 files changed, 592 insertions(+), 275 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index 41bdda4..77cb707 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -96,7 +96,11 @@ case class DataSourceV2ScanExec(
           sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
           sparkContext.env)
         .askSync[Unit](SetReaderPartitions(readerFactories.size))
-      new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories)
+      new ContinuousDataSourceRDD(
+        sparkContext,
+        sqlContext.conf.continuousStreamingExecutorQueueSize,
+        sqlContext.conf.continuousStreamingExecutorPollIntervalMs,
+        readerFactories)
         .asInstanceOf[RDD[InternalRow]]
 
     case r: SupportsScanColumnarBatch if r.enableBatchRead() =>

http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
new file mode 100644
index 0000000..0a3b9dc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader}
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset}
+import org.apache.spark.util.{NextIterator, ThreadUtils}
+
+class ContinuousDataSourceRDDPartition(
+    val index: Int,
+    val readerFactory: DataReaderFactory[UnsafeRow])
+  extends Partition with Serializable {
+
+  // This is semantically a lazy val - it's initialized once the first time a call to
+  // ContinuousDataSourceRDD.compute() needs to access it, so it can be shared across
+  // all compute() calls for a partition. This ensures that one compute() picks up where the
+  // previous one ended.
+  // We don't make it actually a lazy val because it needs input which isn't available here.
+  // This will only be initialized on the executors.
+  private[continuous] var queueReader: ContinuousQueuedDataReader = _
+}
+
+/**
+ * The bottom-most RDD of a continuous processing read task. Wraps a [[ContinuousQueuedDataReader]]
+ * to read from the remote source, and polls that queue for incoming rows.
+ *
+ * Note that continuous processing calls compute() multiple times, and the same
+ * [[ContinuousQueuedDataReader]] instance will/must be shared between each call for the same split.
+ */
+class ContinuousDataSourceRDD(
+    sc: SparkContext,
+    dataQueueSize: Int,
+    epochPollIntervalMs: Long,
+    @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]])
+  extends RDD[UnsafeRow](sc, Nil) {
+
+  override protected def getPartitions: Array[Partition] = {
+    readerFactories.zipWithIndex.map {
+      case (readerFactory, index) => new ContinuousDataSourceRDDPartition(index, readerFactory)
+    }.toArray
+  }
+
+  /**
+   * Initialize the shared reader for this partition if needed, then read rows from it until
+   * it returns null to signal the end of the epoch.
+   */
+  override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
+    // If attempt number isn't 0, this is a task retry, which we don't support.
+    if (context.attemptNumber() != 0) {
+      throw new ContinuousTaskRetryException()
+    }
+
+    val readerForPartition = {
+      val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition]
+      if (partition.queueReader == null) {
+        partition.queueReader =
+          new ContinuousQueuedDataReader(
+            partition.readerFactory, context, dataQueueSize, epochPollIntervalMs)
+      }
+
+      partition.queueReader
+    }
+
+    new NextIterator[UnsafeRow] {
+      override def getNext(): UnsafeRow = {
+        readerForPartition.next() match {
+          case null =>
+            finished = true
+            null
+          case row => row
+        }
+      }
+
+      override def close(): Unit = {}
+    }
+  }
+
+  override def getPreferredLocations(split: Partition): Seq[String] = {
+    split.asInstanceOf[ContinuousDataSourceRDDPartition].readerFactory.preferredLocations()
+  }
+}
+
+object ContinuousDataSourceRDD {
+  private[continuous] def getContinuousReader(
+      reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = {
+    reader match {
+      case r: ContinuousDataReader[UnsafeRow] => r
+      case wrapped: RowToUnsafeDataReader =>
+        wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]]
+      case _ =>
+        throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}")
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
deleted file mode 100644
index 06754f0..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
+++ /dev/null
@@ -1,222 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming.continuous
-
-import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit}
-import java.util.concurrent.atomic.AtomicBoolean
-
-import scala.collection.JavaConverters._
-
-import org.apache.spark._
-import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader}
-import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset}
-import org.apache.spark.util.ThreadUtils
-
-class ContinuousDataSourceRDD(
-    sc: SparkContext,
-    sqlContext: SQLContext,
-    @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]])
-  extends RDD[UnsafeRow](sc, Nil) {
-
-  private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize
-  private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs
-
-  override protected def getPartitions: Array[Partition] = {
-    readerFactories.zipWithIndex.map {
-      case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
-    }.toArray
-  }
-
-  override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
-    // If attempt number isn't 0, this is a task retry, which we don't support.
-    if (context.attemptNumber() != 0) {
-      throw new ContinuousTaskRetryException()
-    }
-
-    val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]]
-      .readerFactory.createDataReader()
-
-    val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
-
-    // This queue contains two types of messages:
-    // * (null, null) representing an epoch boundary.
-    // * (row, off) containing a data row and its corresponding PartitionOffset.
-    val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize)
-
-    val epochPollFailed = new AtomicBoolean(false)
-    val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
-      s"epoch-poll--$coordinatorId--${context.partitionId()}")
-    val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed)
-    epochPollExecutor.scheduleWithFixedDelay(
-      epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
-
-    // Important sequencing - we must get start offset before the data reader thread begins
-    val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset
-
-    val dataReaderFailed = new AtomicBoolean(false)
-    val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed)
-    dataReaderThread.setDaemon(true)
-    dataReaderThread.start()
-
-    context.addTaskCompletionListener(_ => {
-      dataReaderThread.interrupt()
-      epochPollExecutor.shutdown()
-    })
-
-    val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get)
-    new Iterator[UnsafeRow] {
-      private val POLL_TIMEOUT_MS = 1000
-
-      private var currentEntry: (UnsafeRow, PartitionOffset) = _
-      private var currentOffset: PartitionOffset = startOffset
-      private var currentEpoch =
-        context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
-
-      override def hasNext(): Boolean = {
-        while (currentEntry == null) {
-          if (context.isInterrupted() || context.isCompleted()) {
-            currentEntry = (null, null)
-          }
-          if (dataReaderFailed.get()) {
-            throw new SparkException("data read failed", dataReaderThread.failureReason)
-          }
-          if (epochPollFailed.get()) {
-            throw new SparkException("epoch poll failed", epochPollRunnable.failureReason)
-          }
-          currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS)
-        }
-
-        currentEntry match {
-          // epoch boundary marker
-          case (null, null) =>
-            epochEndpoint.send(ReportPartitionOffset(
-              context.partitionId(),
-              currentEpoch,
-              currentOffset))
-            currentEpoch += 1
-            currentEntry = null
-            false
-          // real row
-          case (_, offset) =>
-            currentOffset = offset
-            true
-        }
-      }
-
-      override def next(): UnsafeRow = {
-        if (currentEntry == null) throw new NoSuchElementException("No current row was set")
-        val r = currentEntry._1
-        currentEntry = null
-        r
-      }
-    }
-  }
-
-  override def getPreferredLocations(split: Partition): Seq[String] = {
-    split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations()
-  }
-}
-
-case class EpochPackedPartitionOffset(epoch: Long) extends PartitionOffset
-
-class EpochPollRunnable(
-    queue: BlockingQueue[(UnsafeRow, PartitionOffset)],
-    context: TaskContext,
-    failedFlag: AtomicBoolean)
-  extends Thread with Logging {
-  private[continuous] var failureReason: Throwable = _
-
-  private val epochEndpoint = EpochCoordinatorRef.get(
-    context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
-  private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
-
-  override def run(): Unit = {
-    try {
-      val newEpoch = epochEndpoint.askSync[Long](GetCurrentEpoch)
-      for (i <- currentEpoch to newEpoch - 1) {
-        queue.put((null, null))
-        logDebug(s"Sent marker to start epoch ${i + 1}")
-      }
-      currentEpoch = newEpoch
-    } catch {
-      case t: Throwable =>
-        failureReason = t
-        failedFlag.set(true)
-        throw t
-    }
-  }
-}
-
-class DataReaderThread(
-    reader: DataReader[UnsafeRow],
-    queue: BlockingQueue[(UnsafeRow, PartitionOffset)],
-    context: TaskContext,
-    failedFlag: AtomicBoolean)
-  extends Thread(
-    s"continuous-reader--${context.partitionId()}--" +
-    s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") {
-  private[continuous] var failureReason: Throwable = _
-
-  override def run(): Unit = {
-    TaskContext.setTaskContext(context)
-    val baseReader = ContinuousDataSourceRDD.getBaseReader(reader)
-    try {
-      while (!context.isInterrupted && !context.isCompleted()) {
-        if (!reader.next()) {
-          // Check again, since reader.next() might have blocked through an incoming interrupt.
-          if (!context.isInterrupted && !context.isCompleted()) {
-            throw new IllegalStateException(
-              "Continuous reader reported no elements! Reader should have blocked waiting.")
-          } else {
-            return
-          }
-        }
-
-        queue.put((reader.get().copy(), baseReader.getOffset))
-      }
-    } catch {
-      case _: InterruptedException if context.isInterrupted() =>
-        // Continuous shutdown always involves an interrupt; do nothing and shut down quietly.
-
-      case t: Throwable =>
-        failureReason = t
-        failedFlag.set(true)
-        // Don't rethrow the exception in this thread. It's not needed, and the default Spark
-        // exception handler will kill the executor.
-    } finally {
-      reader.close()
-    }
-  }
-}
-
-object ContinuousDataSourceRDD {
-  private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = {
-    reader match {
-      case r: ContinuousDataReader[UnsafeRow] => r
-      case wrapped: RowToUnsafeDataReader =>
-        wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]]
-      case _ =>
-        throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}")
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
new file mode 100644
index 0000000..01a999f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.io.Closeable
+import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit}
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.{Partition, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
+import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A wrapper for a continuous processing data reader, including a reading queue and epoch markers.
+ *
+ * This will be instantiated once per partition - successive calls to compute() in the
+ * [[ContinuousDataSourceRDD]] will reuse the same reader. This is required to get continuity of
+ * offsets across epochs. Each compute() should call the next() method here until null is returned.
+ */
+class ContinuousQueuedDataReader(
+    factory: DataReaderFactory[UnsafeRow],
+    context: TaskContext,
+    dataQueueSize: Int,
+    epochPollIntervalMs: Long) extends Closeable {
+  private val reader = factory.createDataReader()
+
+  // Important sequencing - we must get our starting point before the provider threads start running
+  private var currentOffset: PartitionOffset =
+    ContinuousDataSourceRDD.getContinuousReader(reader).getOffset
+  private var currentEpoch: Long =
+    context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+
+  /**
+   * The record types in the read buffer.
+   */
+  sealed trait ContinuousRecord
+  case object EpochMarker extends ContinuousRecord
+  case class ContinuousRow(row: UnsafeRow, offset: PartitionOffset) extends ContinuousRecord
+
+  private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize)
+
+  private val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
+  private val epochCoordEndpoint = EpochCoordinatorRef.get(
+    context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
+
+  private val epochMarkerExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+    s"epoch-poll--$coordinatorId--${context.partitionId()}")
+  private val epochMarkerGenerator = new EpochMarkerGenerator
+  epochMarkerExecutor.scheduleWithFixedDelay(
+    epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
+
+  private val dataReaderThread = new DataReaderThread
+  dataReaderThread.setDaemon(true)
+  dataReaderThread.start()
+
+  context.addTaskCompletionListener(_ => {
+    this.close()
+  })
+
+  private def shouldStop() = {
+    context.isInterrupted() || context.isCompleted()
+  }
+
+  /**
+   * Return the next UnsafeRow to be read in the current epoch, or null if the epoch is done.
+   *
+   * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch
+   * will call next() again to start getting rows.
+   */
+  def next(): UnsafeRow = {
+    val POLL_TIMEOUT_MS = 1000
+    var currentEntry: ContinuousRecord = null
+
+    while (currentEntry == null) {
+      if (shouldStop()) {
+        // Force the epoch to end here. The writer will notice the context is interrupted
+        // or completed and not start a new one. This makes it possible to achieve clean
+        // shutdown of the streaming query.
+        // TODO: The obvious generalization of this logic to multiple stages won't work. It's
+        // invalid to send an epoch marker from the bottom of a task if all its child tasks
+        // haven't sent one.
+        currentEntry = EpochMarker
+      } else {
+        if (dataReaderThread.failureReason != null) {
+          throw new SparkException("Data read failed", dataReaderThread.failureReason)
+        }
+        if (epochMarkerGenerator.failureReason != null) {
+          throw new SparkException(
+            "Epoch marker generation failed",
+            epochMarkerGenerator.failureReason)
+        }
+        currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS)
+      }
+    }
+
+    currentEntry match {
+      case EpochMarker =>
+        epochCoordEndpoint.send(ReportPartitionOffset(
+          context.partitionId(), currentEpoch, currentOffset))
+        currentEpoch += 1
+        null
+      case ContinuousRow(row, offset) =>
+        currentOffset = offset
+        row
+    }
+  }
+
+  override def close(): Unit = {
+    dataReaderThread.interrupt()
+    epochMarkerExecutor.shutdown()
+  }
+
+  /**
+   * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when
+   * a new row arrives to the [[DataReader]].
+   */
+  class DataReaderThread extends Thread(
+      s"continuous-reader--${context.partitionId()}--" +
+        s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging {
+    @volatile private[continuous] var failureReason: Throwable = _
+
+    override def run(): Unit = {
+      TaskContext.setTaskContext(context)
+      val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader)
+      try {
+        while (!shouldStop()) {
+          if (!reader.next()) {
+            // Check again, since reader.next() might have blocked through an incoming interrupt.
+            if (!shouldStop()) {
+              throw new IllegalStateException(
+                "Continuous reader reported no elements! Reader should have blocked waiting.")
+            } else {
+              return
+            }
+          }
+
+          queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset))
+        }
+      } catch {
+        case _: InterruptedException =>
+          // Continuous shutdown always involves an interrupt; do nothing and shut down quietly.
+          logInfo(s"shutting down interrupted data reader thread $getName")
+
+        case NonFatal(t) =>
+          failureReason = t
+          logWarning("data reader thread failed", t)
+          // If we throw from this thread, we may kill the executor. Let the parent thread handle
+          // it.
+
+        case t: Throwable =>
+          failureReason = t
+          throw t
+      } finally {
+        reader.close()
+      }
+    }
+  }
+
+  /**
+   * The epoch marker component of [[ContinuousQueuedDataReader]]. Populates the queue with
+   * EpochMarker when a new epoch marker arrives.
+   */
+  class EpochMarkerGenerator extends Runnable with Logging {
+    @volatile private[continuous] var failureReason: Throwable = _
+
+    private val epochCoordEndpoint = EpochCoordinatorRef.get(
+      context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
+    // Note that this is *not* the same as the currentEpoch in [[ContinuousQueuedDataReader]]! That
+    // field represents the epoch wrt the data being processed. The currentEpoch here is just a
+    // counter to ensure we send the appropriate number of markers if we fall behind the driver.
+    private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+
+    override def run(): Unit = {
+      try {
+        val newEpoch = epochCoordEndpoint.askSync[Long](GetCurrentEpoch)
+        // It's possible to fall more than 1 epoch behind if a GetCurrentEpoch RPC ends up taking
+        // a while. We recover by immediately injecting enough epoch markers to catch up. This will
+        // result in some epochs being empty for this partition, but that's fine.
+        for (i <- currentEpoch to newEpoch - 1) {
+          queue.put(EpochMarker)
+          logDebug(s"Sent marker to start epoch ${i + 1}")
+        }
+        currentEpoch = newEpoch
+      } catch {
+        case t: Throwable =>
+          failureReason = t
+          throw t
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
new file mode 100644
index 0000000..91f1576
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.{Partition, SparkEnv, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.util.Utils
+
+/**
+ * The RDD writing to a sink in continuous processing.
+ *
+ * Within each task, we repeatedly call prev.compute(). Each resulting iterator contains the data
+ * to be written for one epoch, which we commit and forward to the driver.
+ *
+ * We keep repeating prev.compute() and writing new epochs until the query is shut down.
+ */
+class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow])
+    extends RDD[Unit](prev) {
+
+  override val partitioner = prev.partitioner
+
+  override def getPartitions: Array[Partition] = prev.partitions
+
+  override def compute(split: Partition, context: TaskContext): Iterator[Unit] = {
+    val epochCoordinator = EpochCoordinatorRef.get(
+      context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
+      SparkEnv.get)
+    var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+
+    while (!context.isInterrupted() && !context.isCompleted()) {
+      var dataWriter: DataWriter[InternalRow] = null
+      // write the data and commit this writer.
+      Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
+        try {
+          val dataIterator = prev.compute(split, context)
+          dataWriter = writeTask.createDataWriter(
+            context.partitionId(), context.attemptNumber(), currentEpoch)
+          while (dataIterator.hasNext) {
+            dataWriter.write(dataIterator.next())
+          }
+          logInfo(s"Writer for partition ${context.partitionId()} " +
+            s"in epoch $currentEpoch is committing.")
+          val msg = dataWriter.commit()
+          epochCoordinator.send(
+            CommitPartitionEpoch(context.partitionId(), currentEpoch, msg)
+          )
+          logInfo(s"Writer for partition ${context.partitionId()} " +
+            s"in epoch $currentEpoch committed.")
+          currentEpoch += 1
+        } catch {
+          case _: InterruptedException =>
+          // Continuous shutdown always involves an interrupt. Just finish the task.
+        }
+      })(catchBlock = {
+        // If there is an error, abort this writer. We enter this callback in the middle of
+        // rethrowing an exception, so compute() will stop executing at this point.
+        logError(s"Writer for partition ${context.partitionId()} is aborting.")
+        if (dataWriter != null) dataWriter.abort()
+        logError(s"Writer for partition ${context.partitionId()} aborted.")
+      })
+    }
+
+    Iterator()
+  }
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    prev = null
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
index ba88ae1..e0af3a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
@@ -46,24 +46,19 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla
       case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
     }
 
-    val rdd = query.execute()
+    val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)
 
     logInfo(s"Start processing data source writer: $writer. " +
-      s"The input RDD has ${rdd.getNumPartitions} partitions.")
-    // Let the epoch coordinator know how many partitions the write RDD has.
+      s"The input RDD has ${rdd.partitions.length} partitions.")
     EpochCoordinatorRef.get(
-        sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
-        sparkContext.env)
+      sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
+      sparkContext.env)
       .askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))
 
     try {
       // Force the RDD to run so continuous processing starts; no data is actually being collected
       // to the driver, as ContinuousWriteRDD outputs nothing.
-      sparkContext.runJob(
-        rdd,
-        (context: TaskContext, iter: Iterator[InternalRow]) =>
-          WriteToContinuousDataSourceExec.run(writerFactory, context, iter),
-        rdd.partitions.indices)
+      rdd.collect()
     } catch {
       case _: InterruptedException =>
         // Interruption is how continuous queries are ended, so accept and ignore the exception.
@@ -80,45 +75,3 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla
     sparkContext.emptyRDD
   }
 }
-
-object WriteToContinuousDataSourceExec extends Logging {
-  def run(
-      writeTask: DataWriterFactory[InternalRow],
-      context: TaskContext,
-      iter: Iterator[InternalRow]): Unit = {
-    val epochCoordinator = EpochCoordinatorRef.get(
-      context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
-      SparkEnv.get)
-    var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
-
-    do {
-      var dataWriter: DataWriter[InternalRow] = null
-      // write the data and commit this writer.
-      Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
-        try {
-          dataWriter = writeTask.createDataWriter(
-            context.partitionId(), context.attemptNumber(), currentEpoch)
-          while (iter.hasNext) {
-            dataWriter.write(iter.next())
-          }
-          logInfo(s"Writer for partition ${context.partitionId()} is committing.")
-          val msg = dataWriter.commit()
-          logInfo(s"Writer for partition ${context.partitionId()} committed.")
-          epochCoordinator.send(
-            CommitPartitionEpoch(context.partitionId(), currentEpoch, msg)
-          )
-          currentEpoch += 1
-        } catch {
-          case _: InterruptedException =>
-          // Continuous shutdown always involves an interrupt. Just finish the task.
-        }
-      })(catchBlock = {
-        // If there is an error, abort this writer. We enter this callback in the middle of
-        // rethrowing an exception, so runContinuous will stop executing at this point.
-        logError(s"Writer for partition ${context.partitionId()} is aborting.")
-        if (dataWriter != null) dataWriter.abort()
-        logError(s"Writer for partition ${context.partitionId()} aborted.")
-      })
-    } while (!context.isInterrupted())
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
new file mode 100644
index 0000000..e755625
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming.continuous
+
+import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue}
+
+import org.mockito.{ArgumentCaptor, Matchers}
+import org.mockito.Mockito._
+import org.scalatest.mockito.MockitoSugar
+
+import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext}
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, PartitionOffset}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.types.{DataType, IntegerType}
+
+class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar {
+  case class LongPartitionOffset(offset: Long) extends PartitionOffset
+
+  val coordinatorId = s"${getClass.getSimpleName}-epochCoordinatorIdForUnitTest"
+  val startEpoch = 0
+
+  var epochEndpoint: RpcEndpointRef = _
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    epochEndpoint = EpochCoordinatorRef.create(
+      mock[StreamWriter],
+      mock[ContinuousReader],
+      mock[ContinuousExecution],
+      coordinatorId,
+      startEpoch,
+      spark,
+      SparkEnv.get)
+  }
+
+  override def afterEach(): Unit = {
+    SparkEnv.get.rpcEnv.stop(epochEndpoint)
+    epochEndpoint = null
+    super.afterEach()
+  }
+
+
+  private val mockContext = mock[TaskContext]
+  when(mockContext.getLocalProperty(ContinuousExecution.START_EPOCH_KEY))
+    .thenReturn(startEpoch.toString)
+  when(mockContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY))
+    .thenReturn(coordinatorId)
+
+  /**
+   * Set up a ContinuousQueuedDataReader for testing. Returns the input queue and the reader;
+   * each row put on the queue becomes the next row produced by the wrapped data reader.
+   */
+  private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = {
+    val queue = new ArrayBlockingQueue[UnsafeRow](1024)
+    val factory = new DataReaderFactory[UnsafeRow] {
+      override def createDataReader() = new ContinuousDataReader[UnsafeRow] {
+        var index = -1
+        var curr: UnsafeRow = _
+
+        override def next() = {
+          curr = queue.take()
+          index += 1
+          true
+        }
+
+        override def get = curr
+
+        override def getOffset = LongPartitionOffset(index)
+
+        override def close() = {}
+      }
+    }
+    val reader = new ContinuousQueuedDataReader(
+      factory,
+      mockContext,
+      dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize,
+      epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs)
+
+    (queue, reader)
+  }
+
+  private def unsafeRow(value: Int) = {
+    UnsafeProjection.create(Array(IntegerType : DataType))(
+      new GenericInternalRow(Array(value: Any)))
+  }
+
+  test("basic data read") {
+    val (input, reader) = setup()
+
+    input.add(unsafeRow(12345))
+    assert(reader.next().getInt(0) == 12345)
+  }
+
+  test("basic epoch marker") {
+    val (input, reader) = setup()
+
+    epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+    assert(reader.next() == null)
+  }
+
+  test("new rows after markers") {
+    val (input, reader) = setup()
+
+    epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+    epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+    epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+    assert(reader.next() == null)
+    assert(reader.next() == null)
+    assert(reader.next() == null)
+    input.add(unsafeRow(11111))
+    input.add(unsafeRow(22222))
+    assert(reader.next().getInt(0) == 11111)
+    assert(reader.next().getInt(0) == 22222)
+  }
+
+  test("new markers after rows") {
+    val (input, reader) = setup()
+
+    input.add(unsafeRow(11111))
+    input.add(unsafeRow(22222))
+    assert(reader.next().getInt(0) == 11111)
+    assert(reader.next().getInt(0) == 22222)
+    epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+    epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+    epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+    assert(reader.next() == null)
+    assert(reader.next() == null)
+    assert(reader.next() == null)
+  }
+
+  test("alternating markers and rows") {
+    val (input, reader) = setup()
+
+    input.add(unsafeRow(11111))
+    assert(reader.next().getInt(0) == 11111)
+    input.add(unsafeRow(22222))
+    assert(reader.next().getInt(0) == 22222)
+    epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+    assert(reader.next() == null)
+    input.add(unsafeRow(33333))
+    assert(reader.next().getInt(0) == 33333)
+    input.add(unsafeRow(44444))
+    assert(reader.next().getInt(0) == 44444)
+    epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+    assert(reader.next() == null)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org