Posted to commits@spark.apache.org by td...@apache.org on 2018/05/04 21:14:46 UTC
spark git commit: [SPARK-24039][SS] Do continuous processing writes with multiple compute() calls
Repository: spark
Updated Branches:
refs/heads/master d04806a23 -> af4dc5028
[SPARK-24039][SS] Do continuous processing writes with multiple compute() calls
## What changes were proposed in this pull request?
Do continuous processing writes with multiple compute() calls.
The current strategy (before this PR) is hacky: we just call next() on an iterator that has already returned hasNext = false, relying on all the whitelisted nodes to handle this properly. This will have to change before we can support more complex query plans. (In particular, I have a WIP at https://github.com/jose-torres/spark/pull/13 which should be able to support aggregates in a single partition with minimal additional work.)
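For illustration, here is a condensed sketch of that old pattern, adapted from the deleted WriteToContinuousDataSourceExec.run in the diff below (epoch coordinator RPCs and abort handling elided, so treat the exact shape as approximate):

    // Old model: a single iterator is pumped across epoch boundaries.
    // hasNext returns false at the end of each epoch, and the outer loop
    // simply calls it again - a violation of the Iterator contract that
    // only works because the whitelisted source nodes tolerate it.
    do {
      val dataWriter = writeTask.createDataWriter(
        context.partitionId(), context.attemptNumber(), currentEpoch)
      while (iter.hasNext) {          // false at each epoch boundary...
        dataWriter.write(iter.next())
      }
      dataWriter.commit()             // ...then we re-pump the same iterator
      currentEpoch += 1
    } while (!context.isInterrupted())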
Most of the changes here are just refactoring to accommodate the new model. The behavioral changes are:
* The writer now calls prev.compute(split, context) once per epoch within the epoch loop.
* ContinuousDataSourceRDD now spawns a ContinuousQueuedDataReader which is shared across multiple calls to compute() for the same partition.
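The two changes above combine as follows, condensed from ContinuousWriteRDD.compute() in the diff below (epoch coordinator RPCs and error handling again elided; a sketch, not the exact code):

    // New model: each epoch gets a fresh, well-behaved iterator from
    // prev.compute(), and that iterator ends (hasNext == false) exactly
    // once, at the epoch boundary.
    while (!context.isInterrupted() && !context.isCompleted()) {
      val dataIterator = prev.compute(split, context)
      val dataWriter = writeTask.createDataWriter(
        context.partitionId(), context.attemptNumber(), currentEpoch)
      while (dataIterator.hasNext) {
        dataWriter.write(dataIterator.next())
      }
      dataWriter.commit()   // the commit message is forwarded to the driver
      currentEpoch += 1
    }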
## How was this patch tested?
Existing unit tests.
Author: Jose Torres <to...@gmail.com>
Closes #21200 from jose-torres/noAggr.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/af4dc502
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/af4dc502
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/af4dc502
Branch: refs/heads/master
Commit: af4dc50280ffcdeda208ef2dc5f8b843389732e5
Parents: d04806a
Author: Jose Torres <to...@gmail.com>
Authored: Fri May 4 14:14:40 2018 -0700
Committer: Tathagata Das <ta...@gmail.com>
Committed: Fri May 4 14:14:40 2018 -0700
----------------------------------------------------------------------
.../datasources/v2/DataSourceV2ScanExec.scala | 6 +-
.../continuous/ContinuousDataSourceRDD.scala | 114 ++++++++++
.../ContinuousDataSourceRDDIter.scala | 222 -------------------
.../continuous/ContinuousQueuedDataReader.scala | 211 ++++++++++++++++++
.../continuous/ContinuousWriteRDD.scala | 90 ++++++++
.../WriteToContinuousDataSourceExec.scala | 57 +----
.../ContinuousQueuedDataReaderSuite.scala | 167 ++++++++++++++
7 files changed, 592 insertions(+), 275 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index 41bdda4..77cb707 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -96,7 +96,11 @@ case class DataSourceV2ScanExec(
sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
sparkContext.env)
.askSync[Unit](SetReaderPartitions(readerFactories.size))
- new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories)
+ new ContinuousDataSourceRDD(
+ sparkContext,
+ sqlContext.conf.continuousStreamingExecutorQueueSize,
+ sqlContext.conf.continuousStreamingExecutorPollIntervalMs,
+ readerFactories)
.asInstanceOf[RDD[InternalRow]]
case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
new file mode 100644
index 0000000..0a3b9dc
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader}
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset}
+import org.apache.spark.util.{NextIterator, ThreadUtils}
+
+class ContinuousDataSourceRDDPartition(
+ val index: Int,
+ val readerFactory: DataReaderFactory[UnsafeRow])
+ extends Partition with Serializable {
+
+ // This is semantically a lazy val - it's initialized the first time a call to
+ // ContinuousDataSourceRDD.compute() needs to access it, and is then shared across
+ // all compute() calls for the partition. This ensures that each compute() picks up where
+ // the previous one ended.
+ // We don't make it an actual lazy val because it needs input that isn't available here.
+ // It will only be initialized on the executors.
+ private[continuous] var queueReader: ContinuousQueuedDataReader = _
+}
+
+/**
+ * The bottom-most RDD of a continuous processing read task. Wraps a [[ContinuousQueuedDataReader]],
+ * which reads from the remote source into a queue, and polls that queue for incoming rows.
+ *
+ * Note that continuous processing calls compute() multiple times, and the same
+ * [[ContinuousQueuedDataReader]] instance must be shared across all those calls for a given split.
+ */
+class ContinuousDataSourceRDD(
+ sc: SparkContext,
+ dataQueueSize: Int,
+ epochPollIntervalMs: Long,
+ @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]])
+ extends RDD[UnsafeRow](sc, Nil) {
+
+ override protected def getPartitions: Array[Partition] = {
+ readerFactories.zipWithIndex.map {
+ case (readerFactory, index) => new ContinuousDataSourceRDDPartition(index, readerFactory)
+ }.toArray
+ }
+
+ /**
+ * Initialize the shared reader for this partition if needed, then read rows from it until
+ * it returns null to signal the end of the epoch.
+ */
+ override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
+ // If attempt number isn't 0, this is a task retry, which we don't support.
+ if (context.attemptNumber() != 0) {
+ throw new ContinuousTaskRetryException()
+ }
+
+ val readerForPartition = {
+ val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition]
+ if (partition.queueReader == null) {
+ partition.queueReader =
+ new ContinuousQueuedDataReader(
+ partition.readerFactory, context, dataQueueSize, epochPollIntervalMs)
+ }
+
+ partition.queueReader
+ }
+
+ new NextIterator[UnsafeRow] {
+ override def getNext(): UnsafeRow = {
+ readerForPartition.next() match {
+ case null =>
+ finished = true
+ null
+ case row => row
+ }
+ }
+
+ override def close(): Unit = {}
+ }
+ }
+
+ override def getPreferredLocations(split: Partition): Seq[String] = {
+ split.asInstanceOf[ContinuousDataSourceRDDPartition].readerFactory.preferredLocations()
+ }
+}
+
+object ContinuousDataSourceRDD {
+ private[continuous] def getContinuousReader(
+ reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = {
+ reader match {
+ case r: ContinuousDataReader[UnsafeRow] => r
+ case wrapped: RowToUnsafeDataReader =>
+ wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]]
+ case _ =>
+ throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}")
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
deleted file mode 100644
index 06754f0..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
+++ /dev/null
@@ -1,222 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming.continuous
-
-import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit}
-import java.util.concurrent.atomic.AtomicBoolean
-
-import scala.collection.JavaConverters._
-
-import org.apache.spark._
-import org.apache.spark.internal.Logging
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDDPartition, RowToUnsafeDataReader}
-import org.apache.spark.sql.sources.v2.reader._
-import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, PartitionOffset}
-import org.apache.spark.util.ThreadUtils
-
-class ContinuousDataSourceRDD(
- sc: SparkContext,
- sqlContext: SQLContext,
- @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]])
- extends RDD[UnsafeRow](sc, Nil) {
-
- private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize
- private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs
-
- override protected def getPartitions: Array[Partition] = {
- readerFactories.zipWithIndex.map {
- case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
- }.toArray
- }
-
- override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
- // If attempt number isn't 0, this is a task retry, which we don't support.
- if (context.attemptNumber() != 0) {
- throw new ContinuousTaskRetryException()
- }
-
- val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]]
- .readerFactory.createDataReader()
-
- val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
-
- // This queue contains two types of messages:
- // * (null, null) representing an epoch boundary.
- // * (row, off) containing a data row and its corresponding PartitionOffset.
- val queue = new ArrayBlockingQueue[(UnsafeRow, PartitionOffset)](dataQueueSize)
-
- val epochPollFailed = new AtomicBoolean(false)
- val epochPollExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
- s"epoch-poll--$coordinatorId--${context.partitionId()}")
- val epochPollRunnable = new EpochPollRunnable(queue, context, epochPollFailed)
- epochPollExecutor.scheduleWithFixedDelay(
- epochPollRunnable, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
-
- // Important sequencing - we must get start offset before the data reader thread begins
- val startOffset = ContinuousDataSourceRDD.getBaseReader(reader).getOffset
-
- val dataReaderFailed = new AtomicBoolean(false)
- val dataReaderThread = new DataReaderThread(reader, queue, context, dataReaderFailed)
- dataReaderThread.setDaemon(true)
- dataReaderThread.start()
-
- context.addTaskCompletionListener(_ => {
- dataReaderThread.interrupt()
- epochPollExecutor.shutdown()
- })
-
- val epochEndpoint = EpochCoordinatorRef.get(coordinatorId, SparkEnv.get)
- new Iterator[UnsafeRow] {
- private val POLL_TIMEOUT_MS = 1000
-
- private var currentEntry: (UnsafeRow, PartitionOffset) = _
- private var currentOffset: PartitionOffset = startOffset
- private var currentEpoch =
- context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
-
- override def hasNext(): Boolean = {
- while (currentEntry == null) {
- if (context.isInterrupted() || context.isCompleted()) {
- currentEntry = (null, null)
- }
- if (dataReaderFailed.get()) {
- throw new SparkException("data read failed", dataReaderThread.failureReason)
- }
- if (epochPollFailed.get()) {
- throw new SparkException("epoch poll failed", epochPollRunnable.failureReason)
- }
- currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS)
- }
-
- currentEntry match {
- // epoch boundary marker
- case (null, null) =>
- epochEndpoint.send(ReportPartitionOffset(
- context.partitionId(),
- currentEpoch,
- currentOffset))
- currentEpoch += 1
- currentEntry = null
- false
- // real row
- case (_, offset) =>
- currentOffset = offset
- true
- }
- }
-
- override def next(): UnsafeRow = {
- if (currentEntry == null) throw new NoSuchElementException("No current row was set")
- val r = currentEntry._1
- currentEntry = null
- r
- }
- }
- }
-
- override def getPreferredLocations(split: Partition): Seq[String] = {
- split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readerFactory.preferredLocations()
- }
-}
-
-case class EpochPackedPartitionOffset(epoch: Long) extends PartitionOffset
-
-class EpochPollRunnable(
- queue: BlockingQueue[(UnsafeRow, PartitionOffset)],
- context: TaskContext,
- failedFlag: AtomicBoolean)
- extends Thread with Logging {
- private[continuous] var failureReason: Throwable = _
-
- private val epochEndpoint = EpochCoordinatorRef.get(
- context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
- private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
-
- override def run(): Unit = {
- try {
- val newEpoch = epochEndpoint.askSync[Long](GetCurrentEpoch)
- for (i <- currentEpoch to newEpoch - 1) {
- queue.put((null, null))
- logDebug(s"Sent marker to start epoch ${i + 1}")
- }
- currentEpoch = newEpoch
- } catch {
- case t: Throwable =>
- failureReason = t
- failedFlag.set(true)
- throw t
- }
- }
-}
-
-class DataReaderThread(
- reader: DataReader[UnsafeRow],
- queue: BlockingQueue[(UnsafeRow, PartitionOffset)],
- context: TaskContext,
- failedFlag: AtomicBoolean)
- extends Thread(
- s"continuous-reader--${context.partitionId()}--" +
- s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") {
- private[continuous] var failureReason: Throwable = _
-
- override def run(): Unit = {
- TaskContext.setTaskContext(context)
- val baseReader = ContinuousDataSourceRDD.getBaseReader(reader)
- try {
- while (!context.isInterrupted && !context.isCompleted()) {
- if (!reader.next()) {
- // Check again, since reader.next() might have blocked through an incoming interrupt.
- if (!context.isInterrupted && !context.isCompleted()) {
- throw new IllegalStateException(
- "Continuous reader reported no elements! Reader should have blocked waiting.")
- } else {
- return
- }
- }
-
- queue.put((reader.get().copy(), baseReader.getOffset))
- }
- } catch {
- case _: InterruptedException if context.isInterrupted() =>
- // Continuous shutdown always involves an interrupt; do nothing and shut down quietly.
-
- case t: Throwable =>
- failureReason = t
- failedFlag.set(true)
- // Don't rethrow the exception in this thread. It's not needed, and the default Spark
- // exception handler will kill the executor.
- } finally {
- reader.close()
- }
- }
-}
-
-object ContinuousDataSourceRDD {
- private[continuous] def getBaseReader(reader: DataReader[UnsafeRow]): ContinuousDataReader[_] = {
- reader match {
- case r: ContinuousDataReader[UnsafeRow] => r
- case wrapped: RowToUnsafeDataReader =>
- wrapped.rowReader.asInstanceOf[ContinuousDataReader[Row]]
- case _ =>
- throw new IllegalStateException(s"Unknown continuous reader type ${reader.getClass}")
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
new file mode 100644
index 0000000..01a999f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.io.Closeable
+import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue, TimeUnit}
+import java.util.concurrent.atomic.AtomicBoolean
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.{Partition, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
+import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A wrapper for a continuous processing data reader, including a reading queue and epoch markers.
+ *
+ * This will be instantiated once per partition - successive calls to compute() in the
+ * [[ContinuousDataSourceRDD]] will reuse the same reader. This is required to get continuity of
+ * offsets across epochs. Each compute() should call the next() method here until null is returned.
+ */
+class ContinuousQueuedDataReader(
+ factory: DataReaderFactory[UnsafeRow],
+ context: TaskContext,
+ dataQueueSize: Int,
+ epochPollIntervalMs: Long) extends Closeable {
+ private val reader = factory.createDataReader()
+
+ // Important sequencing - we must get our starting point before the provider threads start running
+ private var currentOffset: PartitionOffset =
+ ContinuousDataSourceRDD.getContinuousReader(reader).getOffset
+ private var currentEpoch: Long =
+ context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+
+ /**
+ * The record types in the read buffer.
+ */
+ sealed trait ContinuousRecord
+ case object EpochMarker extends ContinuousRecord
+ case class ContinuousRow(row: UnsafeRow, offset: PartitionOffset) extends ContinuousRecord
+
+ private val queue = new ArrayBlockingQueue[ContinuousRecord](dataQueueSize)
+
+ private val coordinatorId = context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)
+ private val epochCoordEndpoint = EpochCoordinatorRef.get(
+ context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
+
+ private val epochMarkerExecutor = ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+ s"epoch-poll--$coordinatorId--${context.partitionId()}")
+ private val epochMarkerGenerator = new EpochMarkerGenerator
+ epochMarkerExecutor.scheduleWithFixedDelay(
+ epochMarkerGenerator, 0, epochPollIntervalMs, TimeUnit.MILLISECONDS)
+
+ private val dataReaderThread = new DataReaderThread
+ dataReaderThread.setDaemon(true)
+ dataReaderThread.start()
+
+ context.addTaskCompletionListener(_ => {
+ this.close()
+ })
+
+ private def shouldStop() = {
+ context.isInterrupted() || context.isCompleted()
+ }
+
+ /**
+ * Return the next UnsafeRow to be read in the current epoch, or null if the epoch is done.
+ *
+ * After returning null, the [[ContinuousDataSourceRDD]] compute() for the following epoch
+ * will call next() again to start getting rows.
+ */
+ def next(): UnsafeRow = {
+ val POLL_TIMEOUT_MS = 1000
+ var currentEntry: ContinuousRecord = null
+
+ while (currentEntry == null) {
+ if (shouldStop()) {
+ // Force the epoch to end here. The writer will notice the context is interrupted
+ // or completed and not start a new one. This makes it possible to achieve clean
+ // shutdown of the streaming query.
+ // TODO: The obvious generalization of this logic to multiple stages won't work. It's
+ // invalid to send an epoch marker from the bottom of a task if all its child tasks
+ // haven't sent one.
+ currentEntry = EpochMarker
+ } else {
+ if (dataReaderThread.failureReason != null) {
+ throw new SparkException("Data read failed", dataReaderThread.failureReason)
+ }
+ if (epochMarkerGenerator.failureReason != null) {
+ throw new SparkException(
+ "Epoch marker generation failed",
+ epochMarkerGenerator.failureReason)
+ }
+ currentEntry = queue.poll(POLL_TIMEOUT_MS, TimeUnit.MILLISECONDS)
+ }
+ }
+
+ currentEntry match {
+ case EpochMarker =>
+ epochCoordEndpoint.send(ReportPartitionOffset(
+ context.partitionId(), currentEpoch, currentOffset))
+ currentEpoch += 1
+ null
+ case ContinuousRow(row, offset) =>
+ currentOffset = offset
+ row
+ }
+ }
+
+ override def close(): Unit = {
+ dataReaderThread.interrupt()
+ epochMarkerExecutor.shutdown()
+ }
+
+ /**
+ * The data component of [[ContinuousQueuedDataReader]]. Pushes (row, offset) to the queue when
+ * a new row arrives at the [[DataReader]].
+ */
+ class DataReaderThread extends Thread(
+ s"continuous-reader--${context.partitionId()}--" +
+ s"${context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY)}") with Logging {
+ @volatile private[continuous] var failureReason: Throwable = _
+
+ override def run(): Unit = {
+ TaskContext.setTaskContext(context)
+ val baseReader = ContinuousDataSourceRDD.getContinuousReader(reader)
+ try {
+ while (!shouldStop()) {
+ if (!reader.next()) {
+ // Check again, since reader.next() might have blocked through an incoming interrupt.
+ if (!shouldStop()) {
+ throw new IllegalStateException(
+ "Continuous reader reported no elements! Reader should have blocked waiting.")
+ } else {
+ return
+ }
+ }
+
+ queue.put(ContinuousRow(reader.get().copy(), baseReader.getOffset))
+ }
+ } catch {
+ case _: InterruptedException =>
+ // Continuous shutdown always involves an interrupt; do nothing and shut down quietly.
+ logInfo(s"shutting down interrupted data reader thread $getName")
+
+ case NonFatal(t) =>
+ failureReason = t
+ logWarning("data reader thread failed", t)
+ // If we throw from this thread, we may kill the executor. Let the parent thread handle
+ // it.
+
+ case t: Throwable =>
+ failureReason = t
+ throw t
+ } finally {
+ reader.close()
+ }
+ }
+ }
+
+ /**
+ * The epoch marker component of [[ContinuousQueuedDataReader]]. Populates the queue with an
+ * EpochMarker whenever the driver advances the current epoch.
+ */
+ class EpochMarkerGenerator extends Runnable with Logging {
+ @volatile private[continuous] var failureReason: Throwable = _
+
+ private val epochCoordEndpoint = EpochCoordinatorRef.get(
+ context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), SparkEnv.get)
+ // Note that this is *not* the same as the currentEpoch in [[ContinuousQueuedDataReader]]! That
+ // field tracks the epoch of the data being processed. The currentEpoch here is just a counter
+ // to ensure we send the appropriate number of markers if we fall behind the driver.
+ private var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+
+ override def run(): Unit = {
+ try {
+ val newEpoch = epochCoordEndpoint.askSync[Long](GetCurrentEpoch)
+ // It's possible to fall more than 1 epoch behind if a GetCurrentEpoch RPC ends up taking
+ // a while. In that case we immediately inject enough epoch markers to catch up. This will
+ // result in some epochs being empty for this partition, but that's fine.
+ for (i <- currentEpoch to newEpoch - 1) {
+ queue.put(EpochMarker)
+ logDebug(s"Sent marker to start epoch ${i + 1}")
+ }
+ currentEpoch = newEpoch
+ } catch {
+ case t: Throwable =>
+ failureReason = t
+ throw t
+ }
+ }
+ }
+}
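For reference, the per-epoch contract a caller of ContinuousQueuedDataReader is expected to follow - this is what the NextIterator in ContinuousDataSourceRDD.compute() above implements; drainOneEpoch and process are illustrative names, not part of the patch:

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow

    // One epoch's read loop against the shared reader. next() returns null
    // exactly when the epoch ends; the next compute() call resumes on the
    // same reader instance, preserving offset continuity across epochs.
    def drainOneEpoch(
        reader: ContinuousQueuedDataReader,
        process: UnsafeRow => Unit): Unit = {
      var row = reader.next()
      while (row != null) {
        process(row)
        row = reader.next()
      }
    }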
http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
new file mode 100644
index 0000000..91f1576
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.{Partition, SparkEnv, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.util.Utils
+
+/**
+ * The RDD writing to a sink in continuous processing.
+ *
+ * Within each task, we repeatedly call prev.compute(). Each resulting iterator contains the data
+ * to be written for one epoch, which we commit and forward to the driver.
+ *
+ * We keep repeating prev.compute() and writing new epochs until the query is shut down.
+ */
+class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactory[InternalRow])
+ extends RDD[Unit](prev) {
+
+ override val partitioner = prev.partitioner
+
+ override def getPartitions: Array[Partition] = prev.partitions
+
+ override def compute(split: Partition, context: TaskContext): Iterator[Unit] = {
+ val epochCoordinator = EpochCoordinatorRef.get(
+ context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
+ SparkEnv.get)
+ var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
+
+ while (!context.isInterrupted() && !context.isCompleted()) {
+ var dataWriter: DataWriter[InternalRow] = null
+ // write the data and commit this writer.
+ Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
+ try {
+ val dataIterator = prev.compute(split, context)
+ dataWriter = writeTask.createDataWriter(
+ context.partitionId(), context.attemptNumber(), currentEpoch)
+ while (dataIterator.hasNext) {
+ dataWriter.write(dataIterator.next())
+ }
+ logInfo(s"Writer for partition ${context.partitionId()} " +
+ s"in epoch $currentEpoch is committing.")
+ val msg = dataWriter.commit()
+ epochCoordinator.send(
+ CommitPartitionEpoch(context.partitionId(), currentEpoch, msg)
+ )
+ logInfo(s"Writer for partition ${context.partitionId()} " +
+ s"in epoch $currentEpoch committed.")
+ currentEpoch += 1
+ } catch {
+ case _: InterruptedException =>
+ // Continuous shutdown always involves an interrupt. Just finish the task.
+ }
+ })(catchBlock = {
+ // If there is an error, abort this writer. We enter this callback in the middle of
+ // rethrowing an exception, so compute() will stop executing at this point.
+ logError(s"Writer for partition ${context.partitionId()} is aborting.")
+ if (dataWriter != null) dataWriter.abort()
+ logError(s"Writer for partition ${context.partitionId()} aborted.")
+ })
+ }
+
+ Iterator()
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ prev = null
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
index ba88ae1..e0af3a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala
@@ -46,24 +46,19 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla
case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
}
- val rdd = query.execute()
+ val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)
logInfo(s"Start processing data source writer: $writer. " +
- s"The input RDD has ${rdd.getNumPartitions} partitions.")
- // Let the epoch coordinator know how many partitions the write RDD has.
+ s"The input RDD has ${rdd.partitions.length} partitions.")
EpochCoordinatorRef.get(
- sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
- sparkContext.env)
+ sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
+ sparkContext.env)
.askSync[Unit](SetWriterPartitions(rdd.getNumPartitions))
try {
// Force the RDD to run so continuous processing starts; no data is actually being collected
// to the driver, as ContinuousWriteRDD outputs nothing.
- sparkContext.runJob(
- rdd,
- (context: TaskContext, iter: Iterator[InternalRow]) =>
- WriteToContinuousDataSourceExec.run(writerFactory, context, iter),
- rdd.partitions.indices)
+ rdd.collect()
} catch {
case _: InterruptedException =>
// Interruption is how continuous queries are ended, so accept and ignore the exception.
@@ -80,45 +75,3 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla
sparkContext.emptyRDD
}
}
-
-object WriteToContinuousDataSourceExec extends Logging {
- def run(
- writeTask: DataWriterFactory[InternalRow],
- context: TaskContext,
- iter: Iterator[InternalRow]): Unit = {
- val epochCoordinator = EpochCoordinatorRef.get(
- context.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
- SparkEnv.get)
- var currentEpoch = context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
-
- do {
- var dataWriter: DataWriter[InternalRow] = null
- // write the data and commit this writer.
- Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
- try {
- dataWriter = writeTask.createDataWriter(
- context.partitionId(), context.attemptNumber(), currentEpoch)
- while (iter.hasNext) {
- dataWriter.write(iter.next())
- }
- logInfo(s"Writer for partition ${context.partitionId()} is committing.")
- val msg = dataWriter.commit()
- logInfo(s"Writer for partition ${context.partitionId()} committed.")
- epochCoordinator.send(
- CommitPartitionEpoch(context.partitionId(), currentEpoch, msg)
- )
- currentEpoch += 1
- } catch {
- case _: InterruptedException =>
- // Continuous shutdown always involves an interrupt. Just finish the task.
- }
- })(catchBlock = {
- // If there is an error, abort this writer. We enter this callback in the middle of
- // rethrowing an exception, so runContinuous will stop executing at this point.
- logError(s"Writer for partition ${context.partitionId()} is aborting.")
- if (dataWriter != null) dataWriter.abort()
- logError(s"Writer for partition ${context.partitionId()} aborted.")
- })
- } while (!context.isInterrupted())
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/af4dc502/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
new file mode 100644
index 0000000..e755625
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming.continuous
+
+import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue}
+
+import org.mockito.{ArgumentCaptor, Matchers}
+import org.mockito.Mockito._
+import org.scalatest.mockito.MockitoSugar
+
+import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext}
+import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, PartitionOffset}
+import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.types.{DataType, IntegerType}
+
+class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar {
+ case class LongPartitionOffset(offset: Long) extends PartitionOffset
+
+ val coordinatorId = s"${getClass.getSimpleName}-epochCoordinatorIdForUnitTest"
+ val startEpoch = 0
+
+ var epochEndpoint: RpcEndpointRef = _
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ epochEndpoint = EpochCoordinatorRef.create(
+ mock[StreamWriter],
+ mock[ContinuousReader],
+ mock[ContinuousExecution],
+ coordinatorId,
+ startEpoch,
+ spark,
+ SparkEnv.get)
+ }
+
+ override def afterEach(): Unit = {
+ SparkEnv.get.rpcEnv.stop(epochEndpoint)
+ epochEndpoint = null
+ super.afterEach()
+ }
+
+
+ private val mockContext = mock[TaskContext]
+ when(mockContext.getLocalProperty(ContinuousExecution.START_EPOCH_KEY))
+ .thenReturn(startEpoch.toString)
+ when(mockContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY))
+ .thenReturn(coordinatorId)
+
+ /**
+ * Set up a ContinuousQueuedDataReader for testing. The blocking queue can be used to send
+ * rows to the wrapped data reader.
+ */
+ private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = {
+ val queue = new ArrayBlockingQueue[UnsafeRow](1024)
+ val factory = new DataReaderFactory[UnsafeRow] {
+ override def createDataReader() = new ContinuousDataReader[UnsafeRow] {
+ var index = -1
+ var curr: UnsafeRow = _
+
+ override def next() = {
+ curr = queue.take()
+ index += 1
+ true
+ }
+
+ override def get = curr
+
+ override def getOffset = LongPartitionOffset(index)
+
+ override def close() = {}
+ }
+ }
+ val reader = new ContinuousQueuedDataReader(
+ factory,
+ mockContext,
+ dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize,
+ epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs)
+
+ (queue, reader)
+ }
+
+ private def unsafeRow(value: Int) = {
+ UnsafeProjection.create(Array(IntegerType : DataType))(
+ new GenericInternalRow(Array(value: Any)))
+ }
+
+ test("basic data read") {
+ val (input, reader) = setup()
+
+ input.add(unsafeRow(12345))
+ assert(reader.next().getInt(0) == 12345)
+ }
+
+ test("basic epoch marker") {
+ val (input, reader) = setup()
+
+ epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+ assert(reader.next() == null)
+ }
+
+ test("new rows after markers") {
+ val (input, reader) = setup()
+
+ epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+ epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+ epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+ assert(reader.next() == null)
+ assert(reader.next() == null)
+ assert(reader.next() == null)
+ input.add(unsafeRow(11111))
+ input.add(unsafeRow(22222))
+ assert(reader.next().getInt(0) == 11111)
+ assert(reader.next().getInt(0) == 22222)
+ }
+
+ test("new markers after rows") {
+ val (input, reader) = setup()
+
+ input.add(unsafeRow(11111))
+ input.add(unsafeRow(22222))
+ assert(reader.next().getInt(0) == 11111)
+ assert(reader.next().getInt(0) == 22222)
+ epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+ epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+ epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+ assert(reader.next() == null)
+ assert(reader.next() == null)
+ assert(reader.next() == null)
+ }
+
+ test("alternating markers and rows") {
+ val (input, reader) = setup()
+
+ input.add(unsafeRow(11111))
+ assert(reader.next().getInt(0) == 11111)
+ input.add(unsafeRow(22222))
+ assert(reader.next().getInt(0) == 22222)
+ epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+ assert(reader.next() == null)
+ input.add(unsafeRow(33333))
+ assert(reader.next().getInt(0) == 33333)
+ input.add(unsafeRow(44444))
+ assert(reader.next().getInt(0) == 44444)
+ epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+ assert(reader.next() == null)
+ }
+}