Posted to commits@spark.apache.org by td...@apache.org on 2018/01/11 18:51:01 UTC
[1/2] spark git commit: [SPARK-22908] Add kafka source and sink for continuous processing.
Repository: spark
Updated Branches:
refs/heads/master 0b2eefb67 -> 6f7aaed80
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 3304f36..97f12ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
}
}
- case _ => throw new AnalysisException(s"$cls does not support data writing.")
+ // Streaming also uses the data source V2 API. So it may be that the data source implements
+ // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving
+ // as though it's a V1 source.
+ case _ => saveToV1Source()
}
} else {
- // Code path for data source v1.
- runCommand(df.sparkSession, "save") {
- DataSource(
- sparkSession = df.sparkSession,
- className = source,
- partitionColumns = partitioningColumns.getOrElse(Nil),
- options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
- }
+ saveToV1Source()
+ }
+ }
+
+ private def saveToV1Source(): Unit = {
+ // Code path for data source v1.
+ runCommand(df.sparkSession, "save") {
+ DataSource(
+ sparkSession = df.sparkSession,
+ className = source,
+ partitionColumns = partitioningColumns.getOrElse(Nil),
+ options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
}
}
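The comment in the hunk above describes the new fallback: a source may implement DataSourceV2 for streaming while offering no V2 batch write path, in which case save() now degrades to the V1 code path instead of throwing. A toy sketch of that dispatch in Scala (the trait names mirror the real ones, but write() and the function shape are simplified placeholders, not Spark's actual signatures):

    // Toy model of the save() dispatch after this change.
    trait DataSourceV2
    trait WriteSupport extends DataSourceV2 { def write(): Unit } // V2 batch-write capability

    def save(instance: AnyRef, saveToV1Source: () => Unit): Unit = instance match {
      case ws: WriteSupport => ws.write()        // class supports V2 batch writes
      case _: DataSourceV2  => saveToV1Source()  // V2 but streaming-only: fall back to V1
      case _                => saveToV1Source()  // plain V1 source
    }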
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
index f0bdf84..a4a857f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
@@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan)
(index, message: WriterCommitMessage) => messages(index) = message
)
- logInfo(s"Data source writer $writer is committing.")
- writer.commit(messages)
- logInfo(s"Data source writer $writer committed.")
+ if (!writer.isInstanceOf[ContinuousWriter]) {
+ logInfo(s"Data source writer $writer is committing.")
+ writer.commit(messages)
+ logInfo(s"Data source writer $writer committed.")
+ }
} catch {
case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] =>
// Interruption is how continuous queries are ended, so accept and ignore the exception.
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 24a8b00..cf27e1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -142,7 +142,8 @@ abstract class StreamExecution(
override val id: UUID = UUID.fromString(streamMetadata.id)
- override val runId: UUID = UUID.randomUUID
+ override def runId: UUID = currentRunId
+ protected var currentRunId = UUID.randomUUID
/**
* Pretty identified string of printing in logs. Format is
@@ -418,11 +419,17 @@ abstract class StreamExecution(
* Blocks the current thread until processing for data from the given `source` has reached at
* least the given `Offset`. This method is intended for use primarily when writing tests.
*/
- private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = {
+ private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = {
assertAwaitThread()
def notDone = {
val localCommittedOffsets = committedOffsets
- !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset
+ if (sources == null) {
+ // sources might not be initialized yet
+ false
+ } else {
+ val source = sources(sourceIndex)
+ !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset
+ }
}
while (notDone) {
@@ -436,7 +443,7 @@ abstract class StreamExecution(
awaitProgressLock.unlock()
}
}
- logDebug(s"Unblocked at $newOffset for $source")
+ logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}")
}
/** A flag to indicate that a batch has completed with no new data available. */
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
index d79e4bd..e700aa4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
@@ -77,7 +77,6 @@ class ContinuousDataSourceRDD(
dataReaderThread.start()
context.addTaskCompletionListener(_ => {
- reader.close()
dataReaderThread.interrupt()
epochPollExecutor.shutdown()
})
@@ -201,6 +200,8 @@ class DataReaderThread(
failedFlag.set(true)
// Don't rethrow the exception in this thread. It's not needed, and the default Spark
// exception handler will kill the executor.
+ } finally {
+ reader.close()
}
}
}
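The two hunks above work together: reader.close() leaves the task-completion listener and moves into a finally block inside DataReaderThread, so the reader is only ever closed by the thread that reads from it and close() cannot race with an in-flight read. A minimal sketch of that ownership pattern (the Closeable reader and surrounding names are illustrative, not the actual classes):

    import java.io.Closeable

    // Sketch: the owning thread closes its own reader; outsiders only interrupt.
    def startReader(reader: Closeable, readOne: () => Unit): Thread = {
      val t = new Thread {
        override def run(): Unit = {
          try {
            while (!Thread.interrupted()) readOne()
          } finally {
            reader.close() // only ever called here, on the reader's own thread
          }
        }
      }
      t.start()
      t // a completion hook would now call t.interrupt() rather than reader.close()
    }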
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index 9657b5e..667410e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.execution.streaming.continuous
+import java.util.UUID
import java.util.concurrent.TimeUnit
+import java.util.function.UnaryOperator
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}
@@ -52,7 +54,7 @@ class ContinuousExecution(
sparkSession, name, checkpointRoot, analyzedPlan, sink,
trigger, triggerClock, outputMode, deleteCheckpointOnStop) {
- @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty
+ @volatile protected var continuousSources: Seq[ContinuousReader] = _
override protected def sources: Seq[BaseStreamingSource] = continuousSources
override lazy val logicalPlan: LogicalPlan = {
@@ -78,15 +80,17 @@ class ContinuousExecution(
}
override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {
- do {
- try {
- runContinuous(sparkSessionForStream)
- } catch {
- case _: InterruptedException if state.get().equals(RECONFIGURING) =>
- // swallow exception and run again
- state.set(ACTIVE)
+ val stateUpdate = new UnaryOperator[State] {
+ override def apply(s: State) = s match {
+ // If we ended the query to reconfigure, reset the state to active.
+ case RECONFIGURING => ACTIVE
+ case _ => s
}
- } while (state.get() == ACTIVE)
+ }
+
+ do {
+ runContinuous(sparkSessionForStream)
+ } while (state.updateAndGet(stateUpdate) == ACTIVE)
}
/**
@@ -120,12 +124,16 @@ class ContinuousExecution(
}
committedOffsets = nextOffsets.toStreamProgress(sources)
- // Forcibly align commit and offset logs by slicing off any spurious offset logs from
- // a previous run. We can't allow commits to an epoch that a previous run reached but
- // this run has not.
- offsetLog.purgeAfter(latestEpochId)
+ // Get to an epoch ID that has definitely never been sent to a sink before. Since sink
+ // commit happens between offset log write and commit log write, this means an epoch ID
+ // which is not in the offset log.
+ val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse {
+ throw new IllegalStateException(
+ s"Offset log had no latest element. This shouldn't be possible because nextOffsets is " +
+ s"an element.")
+ }
+ currentBatchId = latestOffsetEpoch + 1
- currentBatchId = latestEpochId + 1
logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets")
nextOffsets
case None =>
@@ -141,6 +149,7 @@ class ContinuousExecution(
* @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with.
*/
private def runContinuous(sparkSessionForQuery: SparkSession): Unit = {
+ currentRunId = UUID.randomUUID
// A list of attributes that will need to be updated.
val replacements = new ArrayBuffer[(Attribute, Attribute)]
// Translate from continuous relation to the underlying data source.
@@ -225,13 +234,11 @@ class ContinuousExecution(
triggerExecutor.execute(() => {
startTrigger()
- if (reader.needsReconfiguration()) {
- state.set(RECONFIGURING)
+ if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) {
stopSources()
if (queryExecutionThread.isAlive) {
sparkSession.sparkContext.cancelJobGroup(runId.toString)
queryExecutionThread.interrupt()
- // No need to join - this thread is about to end anyway.
}
false
} else if (isActive) {
@@ -259,6 +266,7 @@ class ContinuousExecution(
sparkSessionForQuery, lastExecution)(lastExecution.toRdd)
}
} finally {
+ epochEndpoint.askSync[Unit](StopContinuousExecutionWrites)
SparkEnv.get.rpcEnv.stop(epochEndpoint)
epochUpdateThread.interrupt()
@@ -273,17 +281,22 @@ class ContinuousExecution(
epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = {
assert(continuousSources.length == 1, "only one continuous source supported currently")
- if (partitionOffsets.contains(null)) {
- // If any offset is null, that means the corresponding partition hasn't seen any data yet, so
- // there's nothing meaningful to add to the offset log.
- }
val globalOffset = reader.mergeOffsets(partitionOffsets.toArray)
- synchronized {
- if (queryExecutionThread.isAlive) {
- offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
- } else {
- return
- }
+ val oldOffset = synchronized {
+ offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
+ offsetLog.get(epoch - 1)
+ }
+
+ // If offset hasn't changed since last epoch, there's been no new data.
+ if (oldOffset.contains(OffsetSeq.fill(globalOffset))) {
+ noNewData = true
+ }
+
+ awaitProgressLock.lock()
+ try {
+ awaitProgressLockCondition.signalAll()
+ } finally {
+ awaitProgressLock.unlock()
}
}
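The runActivatedStream rewrite above replaces the swallow-the-InterruptedException dance with an atomic state transition: updateAndGet applies the RECONFIGURING-to-ACTIVE reset and returns the resulting state in one step, so the reset and the loop condition can never interleave with a concurrent stop(). A self-contained sketch of the same pattern (the State values here are stand-ins for StreamExecution's):

    import java.util.concurrent.atomic.AtomicReference
    import java.util.function.UnaryOperator

    sealed trait State
    case object ACTIVE extends State
    case object RECONFIGURING extends State
    case object TERMINATED extends State

    val state = new AtomicReference[State](ACTIVE)

    val stateUpdate = new UnaryOperator[State] {
      // Reconfiguration resets to ACTIVE; any other state (e.g. TERMINATED) ends the loop.
      override def apply(s: State): State = s match {
        case RECONFIGURING => ACTIVE
        case _ => s
      }
    }

    do {
      () // runContinuous(...) would execute one generation of the query here
    } while (state.updateAndGet(stateUpdate) == ACTIVE)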
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
index 98017c3..40dcbec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
@@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable
*/
private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage
+/**
+ * The RpcEndpoint stop() will wait to clear out the message queue before terminating the
+ * object. This can lead to a race condition where the query restarts at epoch n, a new
+ * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1.
+ * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous
+ * message to stop any writes to the ContinuousExecution object.
+ */
+private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage
+
// Init messages
/**
* Set the reader and writer partition counts. Tasks may not be started until the coordinator
@@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator(
override val rpcEnv: RpcEnv)
extends ThreadSafeRpcEndpoint with Logging {
+ private var queryWritesStopped: Boolean = false
+
private var numReaderPartitions: Int = _
private var numWriterPartitions: Int = _
@@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator(
partitionCommits.remove(k)
}
for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) {
- partitionCommits.remove(k)
+ partitionOffsets.remove(k)
}
}
}
override def receive: PartialFunction[Any, Unit] = {
+ // If we just drop these messages, we won't do any writes to the query. The lame duck tasks
+ // won't raise errors or anything; their late messages are simply ignored.
+ case _ if queryWritesStopped => ()
+
case CommitPartitionEpoch(partitionId, epoch, message) =>
logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message")
if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
@@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator(
case SetWriterPartitions(numPartitions) =>
numWriterPartitions = numPartitions
context.reply(())
+
+ case StopContinuousExecutionWrites =>
+ queryWritesStopped = true
+ context.reply(())
}
}
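The StopContinuousExecutionWrites handshake above exists because rpcEnv.stop() returns before the endpoint's queue is drained, so a restarted query could otherwise race with commits from the old coordinator. A toy model of the handshake outside Spark's RpcEnv (ToyCoordinator and the message shapes are invented for illustration):

    import java.util.concurrent.atomic.AtomicBoolean

    class ToyCoordinator {
      private val writesStopped = new AtomicBoolean(false)

      // Analogue of askSync(StopContinuousExecutionWrites): the caller blocks until
      // this returns, after which no queued message can reach the commit logic.
      def stopWritesSync(): Unit = writesStopped.set(true)

      def receive(msg: Any): Unit = {
        if (writesStopped.get()) return // drop lame-duck commits instead of applying them
        msg match {
          case ("commit", epoch: Long) => println(s"commit epoch $epoch")
          case _ => ()
        }
      }
    }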
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index db588ae..b5b4a05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2}
+import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport
/**
* Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
useTempCheckpointLocation = true,
trigger = trigger)
} else {
- val dataSource =
- DataSource(
- df.sparkSession,
- className = source,
- options = extraOptions.toMap,
- partitionColumns = normalizedParCols.getOrElse(Nil))
+ val sink = trigger match {
+ case _: ContinuousTrigger =>
+ val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
+ ds.newInstance() match {
+ case w: ContinuousWriteSupport => w
+ case _ => throw new AnalysisException(
+ s"Data source $source does not support continuous writing")
+ }
+ case _ =>
+ val ds = DataSource(
+ df.sparkSession,
+ className = source,
+ options = extraOptions.toMap,
+ partitionColumns = normalizedParCols.getOrElse(Nil))
+ ds.createSink(outputMode)
+ }
+
df.sparkSession.sessionState.streamingQueryManager.startQuery(
extraOptions.get("queryName"),
extraOptions.get("checkpointLocation"),
df,
extraOptions.toMap,
- dataSource.createSink(outputMode),
+ sink,
outputMode,
useTempCheckpointLocation = source == "console",
recoverFromCheckpointLocation = true,
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index d46461f..0762895 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch}
+import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch}
import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.streaming.StreamingQueryListener._
@@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
StateStore.stop() // stop the state store maintenance thread and unload store providers
}
+ protected val defaultTrigger = Trigger.ProcessingTime(0)
+ protected val defaultUseV2Sink = false
+
/** How long to wait for an active stream to catch up when checking a result. */
val streamingTimeout = 10.seconds
@@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
/** Starts the stream, resuming if data has already been processed. It must not be running. */
case class StartStream(
- trigger: Trigger = Trigger.ProcessingTime(0),
+ trigger: Trigger = defaultTrigger,
triggerClock: Clock = new SystemClock,
additionalConfs: Map[String, String] = Map.empty,
checkpointLocation: String = null)
@@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
def testStream(
_stream: Dataset[_],
outputMode: OutputMode = OutputMode.Append,
- useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized {
+ useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized {
import org.apache.spark.sql.streaming.util.StreamManualClock
// `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently
@@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = {
verify(currentStream != null, "stream not running")
- // Get the map of source index to the current source objects
- val indexToSource = currentStream
- .logicalPlan
- .collect { case StreamingExecutionRelation(s, _) => s }
- .zipWithIndex
- .map(_.swap)
- .toMap
// Block until all data added has been processed for all the source
awaiting.foreach { case (sourceIndex, offset) =>
failAfter(streamingTimeout) {
- currentStream.awaitOffset(indexToSource(sourceIndex), offset)
+ currentStream.awaitOffset(sourceIndex, offset)
}
}
@@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
// after starting the query.
try {
currentStream.awaitInitialization(streamingTimeout.toMillis)
+ currentStream match {
+ case s: ContinuousExecution => eventually("IncrementalExecution was not created") {
+ s.lastExecution.executedPlan // will fail if lastExecution is null
+ }
+ case _ =>
+ }
} catch {
case _: StreamingQueryException =>
// Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well.
@@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
def findSourceIndex(plan: LogicalPlan): Option[Int] = {
plan
- .collect { case StreamingExecutionRelation(s, _) => s }
+ .collect {
+ case StreamingExecutionRelation(s, _) => s
+ case DataSourceV2Relation(_, r) => r
+ }
.zipWithIndex
.find(_._1 == source)
.map(_._2)
@@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
findSourceIndex(query.logicalPlan)
}.orElse {
findSourceIndex(stream.logicalPlan)
+ }.orElse {
+ queryToUse.flatMap { q =>
+ findSourceIndex(q.lastExecution.logical)
+ }
}.getOrElse {
throw new IllegalArgumentException(
- "Could find index of the source to which data was added")
+ "Could not find index of the source to which data was added")
}
// Store the expected offset of added data to wait for it later
---------------------------------------------------------------------
[2/2] spark git commit: [SPARK-22908] Add kafka source and sink for continuous processing.
Posted by td...@apache.org.
[SPARK-22908] Add kafka source and sink for continuous processing.
## What changes were proposed in this pull request?
Add kafka source and sink for continuous processing. This involves two small changes to the execution engine:
* Bring data reader close() into the normal data reader thread to avoid thread safety issues.
* Fix up the semantics of the RECONFIGURING StreamExecution state. State updates are now atomic, and we don't have to deal with swallowing an exception.
## How was this patch tested?
new unit tests
Author: Jose Torres <jo...@databricks.com>
Closes #20096 from jose-torres/continuous-kafka.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6f7aaed8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6f7aaed8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6f7aaed8
Branch: refs/heads/master
Commit: 6f7aaed805070d29dcba32e04ca7a1f581fa54b9
Parents: 0b2eefb
Author: Jose Torres <jo...@databricks.com>
Authored: Thu Jan 11 10:52:12 2018 -0800
Committer: Tathagata Das <ta...@gmail.com>
Committed: Thu Jan 11 10:52:12 2018 -0800
----------------------------------------------------------------------
.../sql/kafka010/KafkaContinuousReader.scala | 232 +++++++++
.../sql/kafka010/KafkaContinuousWriter.scala | 119 +++++
.../spark/sql/kafka010/KafkaOffsetReader.scala | 21 +-
.../apache/spark/sql/kafka010/KafkaSource.scala | 17 +-
.../spark/sql/kafka010/KafkaSourceOffset.scala | 7 +-
.../sql/kafka010/KafkaSourceProvider.scala | 105 +++-
.../spark/sql/kafka010/KafkaWriteTask.scala | 71 +--
.../apache/spark/sql/kafka010/KafkaWriter.scala | 5 +-
.../sql/kafka010/KafkaContinuousSinkSuite.scala | 474 +++++++++++++++++++
.../kafka010/KafkaContinuousSourceSuite.scala | 96 ++++
.../sql/kafka010/KafkaContinuousTest.scala | 64 +++
.../spark/sql/kafka010/KafkaSourceSuite.scala | 470 +++++++++---------
.../org/apache/spark/sql/DataFrameReader.scala | 32 +-
.../org/apache/spark/sql/DataFrameWriter.scala | 25 +-
.../datasources/v2/WriteToDataSourceV2.scala | 8 +-
.../execution/streaming/StreamExecution.scala | 15 +-
.../ContinuousDataSourceRDDIter.scala | 3 +-
.../continuous/ContinuousExecution.scala | 67 +--
.../streaming/continuous/EpochCoordinator.scala | 21 +-
.../spark/sql/streaming/DataStreamWriter.scala | 26 +-
.../apache/spark/sql/streaming/StreamTest.scala | 36 +-
21 files changed, 1531 insertions(+), 383 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
new file mode 100644
index 0000000..9283795
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+
+import org.apache.kafka.clients.consumer.ConsumerRecord
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A [[ContinuousReader]] for data from kafka.
+ *
+ * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be
+ * read by per-task consumers generated later.
+ * @param kafkaParams String params for per-task Kafka consumers.
+ * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which
+ * are not Kafka consumer params.
+ * @param metadataPath Path to a directory this reader can use for writing metadata.
+ * @param initialOffsets The Kafka offsets to start reading data at.
+ * @param failOnDataLoss Flag indicating whether reading should fail in data loss
+ * scenarios, where some offsets after the specified initial ones can't be
+ * properly read.
+ */
+class KafkaContinuousReader(
+ offsetReader: KafkaOffsetReader,
+ kafkaParams: ju.Map[String, Object],
+ sourceOptions: Map[String, String],
+ metadataPath: String,
+ initialOffsets: KafkaOffsetRangeLimit,
+ failOnDataLoss: Boolean)
+ extends ContinuousReader with SupportsScanUnsafeRow with Logging {
+
+ private lazy val session = SparkSession.getActiveSession.get
+ private lazy val sc = session.sparkContext
+
+ // Initialized when creating read tasks. If this diverges from the partitions at the latest
+ // offsets, we need to reconfigure.
+ // Exposed outside this object only for unit tests.
+ private[sql] var knownPartitions: Set[TopicPartition] = _
+
+ override def readSchema: StructType = KafkaOffsetReader.kafkaSchema
+
+ private var offset: Offset = _
+ override def setOffset(start: ju.Optional[Offset]): Unit = {
+ offset = start.orElse {
+ val offsets = initialOffsets match {
+ case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets())
+ case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets())
+ case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss)
+ }
+ logInfo(s"Initial offsets: $offsets")
+ offsets
+ }
+ }
+
+ override def getStartOffset(): Offset = offset
+
+ override def deserializeOffset(json: String): Offset = {
+ KafkaSourceOffset(JsonUtils.partitionOffsets(json))
+ }
+
+ override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = {
+ import scala.collection.JavaConverters._
+
+ val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
+
+ val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet
+ val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet)
+ val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq)
+
+ val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet)
+ if (deletedPartitions.nonEmpty) {
+ reportDataLoss(s"Some partitions were deleted: $deletedPartitions")
+ }
+
+ val startOffsets = newPartitionOffsets ++
+ oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_))
+ knownPartitions = startOffsets.keySet
+
+ startOffsets.toSeq.map {
+ case (topicPartition, start) =>
+ KafkaContinuousReadTask(
+ topicPartition, start, kafkaParams, failOnDataLoss)
+ .asInstanceOf[ReadTask[UnsafeRow]]
+ }.asJava
+ }
+
+ /** Stop this source and free any resources it has allocated. */
+ def stop(): Unit = synchronized {
+ offsetReader.close()
+ }
+
+ override def commit(end: Offset): Unit = {}
+
+ override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
+ val mergedMap = offsets.map {
+ case KafkaSourcePartitionOffset(p, o) => Map(p -> o)
+ }.reduce(_ ++ _)
+ KafkaSourceOffset(mergedMap)
+ }
+
+ override def needsReconfiguration(): Boolean = {
+ knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions
+ }
+
+ override def toString(): String = s"KafkaSource[$offsetReader]"
+
+ /**
+ * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`.
+ * Otherwise, just log a warning.
+ */
+ private def reportDataLoss(message: String): Unit = {
+ if (failOnDataLoss) {
+ throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE")
+ } else {
+ logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE")
+ }
+ }
+}
+
+/**
+ * A read task for continuous Kafka processing. This will be serialized and transformed into a
+ * full reader on executors.
+ *
+ * @param topicPartition The (topic, partition) pair this task is responsible for.
+ * @param startOffset The offset to start reading from within the partition.
+ * @param kafkaParams Kafka consumer params to use.
+ * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
+ * are skipped.
+ */
+case class KafkaContinuousReadTask(
+ topicPartition: TopicPartition,
+ startOffset: Long,
+ kafkaParams: ju.Map[String, Object],
+ failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] {
+ override def createDataReader(): KafkaContinuousDataReader = {
+ new KafkaContinuousDataReader(topicPartition, startOffset, kafkaParams, failOnDataLoss)
+ }
+}
+
+/**
+ * A per-task data reader for continuous Kafka processing.
+ *
+ * @param topicPartition The (topic, partition) pair this data reader is responsible for.
+ * @param startOffset The offset to start reading from within the partition.
+ * @param kafkaParams Kafka consumer params to use.
+ * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
+ * are skipped.
+ */
+class KafkaContinuousDataReader(
+ topicPartition: TopicPartition,
+ startOffset: Long,
+ kafkaParams: ju.Map[String, Object],
+ failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] {
+ private val topic = topicPartition.topic
+ private val kafkaPartition = topicPartition.partition
+ private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams)
+
+ private val sharedRow = new UnsafeRow(7)
+ private val bufferHolder = new BufferHolder(sharedRow)
+ private val rowWriter = new UnsafeRowWriter(bufferHolder, 7)
+
+ private var nextKafkaOffset = startOffset
+ private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _
+
+ override def next(): Boolean = {
+ var r: ConsumerRecord[Array[Byte], Array[Byte]] = null
+ while (r == null) {
+ r = consumer.get(
+ nextKafkaOffset,
+ untilOffset = Long.MaxValue,
+ pollTimeoutMs = Long.MaxValue,
+ failOnDataLoss)
+ }
+ nextKafkaOffset = r.offset + 1
+ currentRecord = r
+ true
+ }
+
+ override def get(): UnsafeRow = {
+ bufferHolder.reset()
+
+ if (currentRecord.key == null) {
+ rowWriter.setNullAt(0)
+ } else {
+ rowWriter.write(0, currentRecord.key)
+ }
+ rowWriter.write(1, currentRecord.value)
+ rowWriter.write(2, UTF8String.fromString(currentRecord.topic))
+ rowWriter.write(3, currentRecord.partition)
+ rowWriter.write(4, currentRecord.offset)
+ rowWriter.write(5,
+ DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp)))
+ rowWriter.write(6, currentRecord.timestampType.id)
+ sharedRow.setTotalSize(bufferHolder.totalSize)
+ sharedRow
+ }
+
+ override def getOffset(): KafkaSourcePartitionOffset = {
+ KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset)
+ }
+
+ override def close(): Unit = {
+ consumer.close()
+ }
+}
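Once KafkaSourceProvider wires this reader in (see its createContinuousReader below), a continuous Kafka read is driven entirely by trigger choice. A hedged usage sketch, assuming this patch is applied, a broker at localhost:9092, the topic names shown, and the Trigger.Continuous helper available on this branch:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.streaming.Trigger

    val spark = SparkSession.builder().appName("continuous-kafka-demo").getOrCreate()

    val query = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092") // assumed broker address
      .option("subscribe", "in-topic")                     // assumed input topic
      .load()
      .selectExpr("key", "value")                          // pass bytes through unchanged
      .writeStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092")
      .option("topic", "out-topic")                        // assumed output topic
      .option("checkpointLocation", "/tmp/continuous-kafka-cp")
      .trigger(Trigger.Continuous("1 second"))             // epochs committed every second
      .start()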
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala
new file mode 100644
index 0000000..9843f46
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata}
+import scala.collection.JavaConverters._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
+import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY}
+import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
+import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
+import org.apache.spark.sql.sources.v2.writer._
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{BinaryType, StringType, StructType}
+
+/**
+ * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we
+ * don't need to really send one.
+ */
+case object KafkaWriterCommitMessage extends WriterCommitMessage
+
+/**
+ * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory.
+ * @param topic The topic this writer is responsible for. If None, topic will be inferred from
+ * a `topic` field in the incoming data.
+ * @param producerParams Parameters for Kafka producers in each task.
+ * @param schema The schema of the input data.
+ */
+class KafkaContinuousWriter(
+ topic: Option[String], producerParams: Map[String, String], schema: StructType)
+ extends ContinuousWriter with SupportsWriteInternalRow {
+
+ validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)
+
+ override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory =
+ KafkaContinuousWriterFactory(topic, producerParams, schema)
+
+ override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+ override def abort(messages: Array[WriterCommitMessage]): Unit = {}
+}
+
+/**
+ * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate
+ * the per-task data writers.
+ * @param topic The topic that should be written to. If None, topic will be inferred from
+ * a `topic` field in the incoming data.
+ * @param producerParams Parameters for Kafka producers in each task.
+ * @param schema The schema of the input data.
+ */
+case class KafkaContinuousWriterFactory(
+ topic: Option[String], producerParams: Map[String, String], schema: StructType)
+ extends DataWriterFactory[InternalRow] {
+
+ override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
+ new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes)
+ }
+}
+
+/**
+ * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to
+ * process incoming rows.
+ *
+ * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred
+ * from a `topic` field in the incoming data.
+ * @param producerParams Parameters to use for the Kafka producer.
+ * @param inputSchema The attributes in the input data.
+ */
+class KafkaContinuousDataWriter(
+ targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute])
+ extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] {
+ import scala.collection.JavaConverters._
+
+ private lazy val producer = CachedKafkaProducer.getOrCreate(
+ new java.util.HashMap[String, Object](producerParams.asJava))
+
+ def write(row: InternalRow): Unit = {
+ checkForErrors()
+ sendRow(row, producer)
+ }
+
+ def commit(): WriterCommitMessage = {
+ // Send is asynchronous, but we can't commit until all rows are actually in Kafka.
+ // This requires flushing and then checking that no callbacks produced errors.
+ // We also check for errors before to fail as soon as possible - the check is cheap.
+ checkForErrors()
+ producer.flush()
+ checkForErrors()
+ KafkaWriterCommitMessage
+ }
+
+ def abort(): Unit = {}
+
+ def close(): Unit = {
+ checkForErrors()
+ if (producer != null) {
+ producer.flush()
+ checkForErrors()
+ CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava))
+ }
+ }
+}
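commit() above leans on a property of the Kafka producer: send() is asynchronous, so a row is only known to be in Kafka after flush() returns and no completion callback has recorded an error. A standalone illustration of that send/flush/check discipline against the plain producer API (broker address and topic are assumptions):

    import java.util.Properties
    import java.util.concurrent.atomic.AtomicReference

    import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
    import org.apache.kafka.common.serialization.ByteArraySerializer

    val props = new Properties()
    props.put("bootstrap.servers", "localhost:9092") // assumed broker
    props.put("key.serializer", classOf[ByteArraySerializer].getName)
    props.put("value.serializer", classOf[ByteArraySerializer].getName)

    val producer = new KafkaProducer[Array[Byte], Array[Byte]](props)
    val failed = new AtomicReference[Exception](null)

    val callback = new Callback {
      override def onCompletion(md: RecordMetadata, e: Exception): Unit =
        if (e != null) failed.compareAndSet(null, e) // remember the first failure only
    }

    producer.send(new ProducerRecord("demo-topic", "k".getBytes, "v".getBytes), callback)

    producer.flush() // blocks until every buffered send has completed and run its callback
    if (failed.get() != null) throw failed.get() // only now is the row known to be in Kafka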
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
index 3e65949..551641c 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
@@ -117,10 +117,14 @@ private[kafka010] class KafkaOffsetReader(
* Resolves the specific offsets based on Kafka seek positions.
* This method resolves offset value -1 to the latest and -2 to the
* earliest Kafka seek position.
+ *
+ * @param partitionOffsets the specific offsets to resolve
+ * @param reportDataLoss callback to either report or log data loss depending on setting
*/
def fetchSpecificOffsets(
- partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] =
- runUninterruptibly {
+ partitionOffsets: Map[TopicPartition, Long],
+ reportDataLoss: String => Unit): KafkaSourceOffset = {
+ val fetched = runUninterruptibly {
withRetriesWithoutInterrupt {
// Poll to get the latest assigned partitions
consumer.poll(0)
@@ -145,6 +149,19 @@ private[kafka010] class KafkaOffsetReader(
}
}
+ partitionOffsets.foreach {
+ case (tp, off) if off != KafkaOffsetRangeLimit.LATEST &&
+ off != KafkaOffsetRangeLimit.EARLIEST =>
+ if (fetched(tp) != off) {
+ reportDataLoss(
+ s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}")
+ }
+ case _ =>
+ // no real way to check that beginning or end is reasonable
+ }
+ KafkaSourceOffset(fetched)
+ }
+
/**
* Fetch the earliest offsets for the topic partitions that are indicated
* in the [[ConsumerStrategy]].
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index e9cff04..27da760 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -130,7 +130,7 @@ private[kafka010] class KafkaSource(
val offsets = startingOffsets match {
case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets())
case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets())
- case SpecificOffsetRangeLimit(p) => fetchAndVerify(p)
+ case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss)
}
metadataLog.add(0, offsets)
logInfo(s"Initial offsets: $offsets")
@@ -138,21 +138,6 @@ private[kafka010] class KafkaSource(
}.partitionToOffsets
}
- private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = {
- val result = kafkaReader.fetchSpecificOffsets(specificOffsets)
- specificOffsets.foreach {
- case (tp, off) if off != KafkaOffsetRangeLimit.LATEST &&
- off != KafkaOffsetRangeLimit.EARLIEST =>
- if (result(tp) != off) {
- reportDataLoss(
- s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}")
- }
- case _ =>
- // no real way to check that beginning or end is reasonable
- }
- KafkaSourceOffset(result)
- }
-
private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None
override def schema: StructType = KafkaOffsetReader.kafkaSchema
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala
index b5da415..c82154c 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala
@@ -20,17 +20,22 @@ package org.apache.spark.sql.kafka010
import org.apache.kafka.common.TopicPartition
import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset}
+import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset}
/**
* An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and
* their offsets.
*/
private[kafka010]
-case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset {
+case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 {
override val json = JsonUtils.partitionOffsets(partitionToOffsets)
}
+private[kafka010]
+case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long)
+ extends PartitionOffset
+
/** Companion object of the [[KafkaSourceOffset]] */
private[kafka010] object KafkaSourceOffset {
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index 3cb4d8c..3914370 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.kafka010
import java.{util => ju}
-import java.util.{Locale, UUID}
+import java.util.{Locale, Optional, UUID}
import scala.collection.JavaConverters._
@@ -27,9 +27,12 @@ import org.apache.kafka.clients.producer.ProducerConfig
import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer}
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
-import org.apache.spark.sql.execution.streaming.{Sink, Source}
+import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source}
import org.apache.spark.sql.sources._
+import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options}
+import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport}
+import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
@@ -43,6 +46,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
with StreamSinkProvider
with RelationProvider
with CreatableRelationProvider
+ with ContinuousWriteSupport
+ with ContinuousReadSupport
with Logging {
import KafkaSourceProvider._
@@ -101,6 +106,43 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
failOnDataLoss(caseInsensitiveParams))
}
+ override def createContinuousReader(
+ schema: Optional[StructType],
+ metadataPath: String,
+ options: DataSourceV2Options): KafkaContinuousReader = {
+ val parameters = options.asMap().asScala.toMap
+ validateStreamOptions(parameters)
+ // Each running query should use its own group id. Otherwise, the query may be only assigned
+ // partial data since Kafka will assign partitions to multiple consumers having the same group
+ // id. Hence, we should generate a unique id for each query.
+ val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}"
+
+ val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+ val specifiedKafkaParams =
+ parameters
+ .keySet
+ .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
+ .map { k => k.drop(6).toString -> parameters(k) }
+ .toMap
+
+ val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
+ STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+
+ val kafkaOffsetReader = new KafkaOffsetReader(
+ strategy(caseInsensitiveParams),
+ kafkaParamsForDriver(specifiedKafkaParams),
+ parameters,
+ driverGroupIdPrefix = s"$uniqueGroupId-driver")
+
+ new KafkaContinuousReader(
+ kafkaOffsetReader,
+ kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
+ parameters,
+ metadataPath,
+ startingStreamOffsets,
+ failOnDataLoss(caseInsensitiveParams))
+ }
+
/**
* Returns a new base relation with the given parameters.
*
@@ -181,26 +223,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}
}
- private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = {
- val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
- if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
- throw new IllegalArgumentException(
- s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys "
- + "are serialized with ByteArraySerializer.")
- }
+ override def createContinuousWriter(
+ queryId: String,
+ schema: StructType,
+ mode: OutputMode,
+ options: DataSourceV2Options): Optional[ContinuousWriter] = {
+ import scala.collection.JavaConverters._
- if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}"))
- {
- throw new IllegalArgumentException(
- s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as "
- + "value are serialized with ByteArraySerializer.")
- }
- parameters
- .keySet
- .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
- .map { k => k.drop(6).toString -> parameters(k) }
- .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName,
- ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName)
+ val spark = SparkSession.getActiveSession.get
+ val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim)
+ // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable.
+ val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap)
+
+ KafkaWriter.validateQuery(
+ schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic)
+
+ Optional.of(new KafkaContinuousWriter(topic, producerParams, schema))
}
private def strategy(caseInsensitiveParams: Map[String, String]) =
@@ -450,4 +488,27 @@ private[kafka010] object KafkaSourceProvider extends Logging {
def build(): ju.Map[String, Object] = map
}
+
+ private[kafka010] def kafkaParamsForProducer(
+ parameters: Map[String, String]): Map[String, String] = {
+ val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+ if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
+ throw new IllegalArgumentException(
+ s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys "
+ + "are serialized with ByteArraySerializer.")
+ }
+
+ if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}"))
+ {
+ throw new IllegalArgumentException(
+ s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as "
+ + "values are serialized with ByteArraySerializer.")
+ }
+ parameters
+ .keySet
+ .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
+ .map { k => k.drop(6).toString -> parameters(k) }
+ .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName,
+ ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName)
+ }
}
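createContinuousReader above generates a fresh consumer group per query because Kafka splits a topic's partitions among consumers sharing a group id; two queries in one group would each see only part of the data. A sketch of the id scheme, mirroring the uniqueGroupId construction in the diff:

    import java.util.UUID

    // One group per (query run, checkpoint): the UUID keeps restarts and concurrent
    // queries apart, the metadataPath hash ties the id to this query's checkpoint.
    def uniqueGroupId(metadataPath: String): String =
      s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}"

    // e.g. uniqueGroupId("/checkpoints/q1/sources/0")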
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
index 6fd333e..baa60fe 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
@@ -33,10 +33,8 @@ import org.apache.spark.sql.types.{BinaryType, StringType}
private[kafka010] class KafkaWriteTask(
producerConfiguration: ju.Map[String, Object],
inputSchema: Seq[Attribute],
- topic: Option[String]) {
+ topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) {
// used to synchronize with Kafka callbacks
- @volatile private var failedWrite: Exception = null
- private val projection = createProjection
private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _
/**
@@ -46,23 +44,7 @@ private[kafka010] class KafkaWriteTask(
producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
while (iterator.hasNext && failedWrite == null) {
val currentRow = iterator.next()
- val projectedRow = projection(currentRow)
- val topic = projectedRow.getUTF8String(0)
- val key = projectedRow.getBinary(1)
- val value = projectedRow.getBinary(2)
- if (topic == null) {
- throw new NullPointerException(s"null topic present in the data. Use the " +
- s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
- }
- val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
- val callback = new Callback() {
- override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
- if (failedWrite == null && e != null) {
- failedWrite = e
- }
- }
- }
- producer.send(record, callback)
+ sendRow(currentRow, producer)
}
}
@@ -74,8 +56,49 @@ private[kafka010] class KafkaWriteTask(
producer = null
}
}
+}
+
+private[kafka010] abstract class KafkaRowWriter(
+ inputSchema: Seq[Attribute], topic: Option[String]) {
+
+ // used to synchronize with Kafka callbacks
+ @volatile protected var failedWrite: Exception = _
+ protected val projection = createProjection
+
+ private val callback = new Callback() {
+ override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
+ if (failedWrite == null && e != null) {
+ failedWrite = e
+ }
+ }
+ }
- private def createProjection: UnsafeProjection = {
+ /**
+ * Send the specified row to the producer, with a callback that will save any exception
+ * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before
+ * assuming the row is in Kafka.
+ */
+ protected def sendRow(
+ row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = {
+ val projectedRow = projection(row)
+ val topic = projectedRow.getUTF8String(0)
+ val key = projectedRow.getBinary(1)
+ val value = projectedRow.getBinary(2)
+ if (topic == null) {
+ throw new NullPointerException(s"null topic present in the data. Use the " +
+ s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
+ }
+ val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
+ producer.send(record, callback)
+ }
+
+ protected def checkForErrors(): Unit = {
+ if (failedWrite != null) {
+ throw failedWrite
+ }
+ }
+
+ private def createProjection = {
val topicExpression = topic.map(Literal(_)).orElse {
inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME)
}.getOrElse {
@@ -112,11 +135,5 @@ private[kafka010] class KafkaWriteTask(
Seq(topicExpression, Cast(keyExpression, BinaryType),
Cast(valueExpression, BinaryType)), inputSchema)
}
-
- private def checkForErrors(): Unit = {
- if (failedWrite != null) {
- throw failedWrite
- }
- }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
index 5e9ae35..15cd448 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
@@ -43,10 +43,9 @@ private[kafka010] object KafkaWriter extends Logging {
override def toString: String = "KafkaWriter"
def validateQuery(
- queryExecution: QueryExecution,
+ schema: Seq[Attribute],
kafkaParameters: ju.Map[String, Object],
topic: Option[String] = None): Unit = {
- val schema = queryExecution.analyzed.output
schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse(
if (topic.isEmpty) {
throw new AnalysisException(s"topic option required when no " +
@@ -84,7 +83,7 @@ private[kafka010] object KafkaWriter extends Logging {
kafkaParameters: ju.Map[String, Object],
topic: Option[String] = None): Unit = {
val schema = queryExecution.analyzed.output
- validateQuery(queryExecution, kafkaParameters, topic)
+ validateQuery(schema, kafkaParameters, topic)
queryExecution.toRdd.foreachPartition { iter =>
val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic)
Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
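
Taking a schema instead of a QueryExecution means validation no longer requires an executed batch plan, which is what lets the continuous Kafka sink reuse it. A sketch of a call site under the new signature (the topic name and bootstrap address are placeholders):

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
    import org.apache.spark.sql.types.BinaryType

    val schema: Seq[Attribute] = Seq(
      AttributeReference("key", BinaryType)(),
      AttributeReference("value", BinaryType)())
    val kafkaParams = Map[String, Object](
      "bootstrap.servers" -> "localhost:9092").asJava
    // Throws AnalysisException if no topic is available or attribute types are wrong.
    KafkaWriter.validateQuery(schema, kafkaParams, Some("events"))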
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
new file mode 100644
index 0000000..dfc97b1
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
@@ -0,0 +1,474 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.util.Locale
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.kafka.clients.producer.ProducerConfig
+import org.apache.kafka.common.serialization.ByteArraySerializer
+import org.scalatest.time.SpanSugar._
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming._
+import org.apache.spark.sql.types.{BinaryType, DataType}
+import org.apache.spark.util.Utils
+
+/**
+ * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream.
+ * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have
+ * to duplicate all the code.
+ */
+class KafkaContinuousSinkSuite extends KafkaContinuousTest {
+ import testImplicits._
+
+ override val streamingTimeout = 30.seconds
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ testUtils = new KafkaTestUtils(
+ withBrokerProps = Map("auto.create.topics.enable" -> "false"))
+ testUtils.setup()
+ }
+
+ override def afterAll(): Unit = {
+ if (testUtils != null) {
+ testUtils.teardown()
+ testUtils = null
+ }
+ super.afterAll()
+ }
+
+ test("streaming - write to kafka with topic field") {
+ val inputTopic = newTopic()
+ testUtils.createTopic(inputTopic, partitions = 1)
+
+ val input = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", inputTopic)
+ .option("startingOffsets", "earliest")
+ .load()
+
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+
+ val writer = createKafkaWriter(
+ input.toDF(),
+ withTopic = None,
+ withOutputMode = Some(OutputMode.Append))(
+ withSelectExpr = s"'$topic' as topic", "value")
+
+ val reader = createKafkaReader(topic)
+ .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
+ .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
+ .as[(Int, Int)]
+ .map(_._2)
+
+ try {
+ testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+ failAfter(streamingTimeout) {
+ writer.processAllAvailable()
+ }
+ checkDatasetUnorderly(reader, 1, 2, 3, 4, 5)
+ testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10"))
+ failAfter(streamingTimeout) {
+ writer.processAllAvailable()
+ }
+ checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+ } finally {
+ writer.stop()
+ }
+ }
+
+ test("streaming - write w/o topic field, with topic option") {
+ val inputTopic = newTopic()
+ testUtils.createTopic(inputTopic, partitions = 1)
+
+ val input = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", inputTopic)
+ .option("startingOffsets", "earliest")
+ .load()
+
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+
+ val writer = createKafkaWriter(
+ input.toDF(),
+ withTopic = Some(topic),
+ withOutputMode = Some(OutputMode.Append()))()
+
+ val reader = createKafkaReader(topic)
+ .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
+ .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
+ .as[(Int, Int)]
+ .map(_._2)
+
+ try {
+ testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+ failAfter(streamingTimeout) {
+ writer.processAllAvailable()
+ }
+ checkDatasetUnorderly(reader, 1, 2, 3, 4, 5)
+ testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10"))
+ failAfter(streamingTimeout) {
+ writer.processAllAvailable()
+ }
+ checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+ } finally {
+ writer.stop()
+ }
+ }
+
+ test("streaming - topic field and topic option") {
+ /* The purpose of this test is to ensure that the topic option
+ * overrides the topic field. We begin by writing data whose topic
+ * field is the literal 'foo' while also supplying a topic option.
+ * When we then read from the topic given in the option, we should
+ * see all the data, which proves it was written to the option's
+ * topic rather than to the topic named in the data (foo).
+ */
+ val inputTopic = newTopic()
+ testUtils.createTopic(inputTopic, partitions = 1)
+
+ val input = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", inputTopic)
+ .option("startingOffsets", "earliest")
+ .load()
+
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+
+ val writer = createKafkaWriter(
+ input.toDF(),
+ withTopic = Some(topic),
+ withOutputMode = Some(OutputMode.Append()))(
+ withSelectExpr = "'foo' as topic", "CAST(value as STRING) value")
+
+ val reader = createKafkaReader(topic)
+ .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+ .selectExpr("CAST(key AS INT)", "CAST(value AS INT)")
+ .as[(Int, Int)]
+ .map(_._2)
+
+ try {
+ testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+ failAfter(streamingTimeout) {
+ writer.processAllAvailable()
+ }
+ checkDatasetUnorderly(reader, 1, 2, 3, 4, 5)
+ testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10"))
+ failAfter(streamingTimeout) {
+ writer.processAllAvailable()
+ }
+ checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+ } finally {
+ writer.stop()
+ }
+ }
+
+ test("null topic attribute") {
+ val inputTopic = newTopic()
+ testUtils.createTopic(inputTopic, partitions = 1)
+
+ val input = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", inputTopic)
+ .option("startingOffsets", "earliest")
+ .load()
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+
+ /* No topic field or topic option */
+ var writer: StreamingQuery = null
+ var ex: Exception = null
+ try {
+ ex = intercept[StreamingQueryException] {
+ writer = createKafkaWriter(input.toDF())(
+ withSelectExpr = "CAST(null as STRING) as topic", "value"
+ )
+ testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+ writer.processAllAvailable()
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getCause.getCause.getMessage
+ .toLowerCase(Locale.ROOT)
+ .contains("null topic present in the data."))
+ }
+
+ test("streaming - write data with bad schema") {
+ val inputTopic = newTopic()
+ testUtils.createTopic(inputTopic, partitions = 1)
+
+ val input = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", inputTopic)
+ .option("startingOffsets", "earliest")
+ .load()
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+
+ /* No topic field or topic option */
+ var writer: StreamingQuery = null
+ var ex: Exception = null
+ try {
+ ex = intercept[StreamingQueryException] {
+ writer = createKafkaWriter(input.toDF())(
+ withSelectExpr = "value as key", "value"
+ )
+ testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+ writer.processAllAvailable()
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getMessage
+ .toLowerCase(Locale.ROOT)
+ .contains("topic option required when no 'topic' attribute is present"))
+
+ try {
+ /* No value field */
+ ex = intercept[StreamingQueryException] {
+ writer = createKafkaWriter(input.toDF())(
+ withSelectExpr = s"'$topic' as topic", "value as key"
+ )
+ testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+ writer.processAllAvailable()
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
+ "required attribute 'value' not found"))
+ }
+
+ test("streaming - write data with valid schema but wrong types") {
+ val inputTopic = newTopic()
+ testUtils.createTopic(inputTopic, partitions = 1)
+
+ val input = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", inputTopic)
+ .option("startingOffsets", "earliest")
+ .load()
+ .selectExpr("CAST(value as STRING) value")
+ val topic = newTopic()
+ testUtils.createTopic(topic)
+
+ var writer: StreamingQuery = null
+ var ex: Exception = null
+ try {
+ /* topic field wrong type */
+ ex = intercept[StreamingQueryException] {
+ writer = createKafkaWriter(input.toDF())(
+ withSelectExpr = s"CAST('1' as INT) as topic", "value"
+ )
+ testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+ writer.processAllAvailable()
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string"))
+
+ try {
+ /* value field wrong type */
+ ex = intercept[StreamingQueryException] {
+ writer = createKafkaWriter(input.toDF())(
+ withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value"
+ )
+ testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+ writer.processAllAvailable()
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
+ "value attribute type must be a string or binarytype"))
+
+ try {
+ ex = intercept[StreamingQueryException] {
+ /* key field wrong type */
+ writer = createKafkaWriter(input.toDF())(
+ withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value"
+ )
+ testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+ writer.processAllAvailable()
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
+ "key attribute type must be a string or binarytype"))
+ }
+
+ test("streaming - write to non-existing topic") {
+ val inputTopic = newTopic()
+ testUtils.createTopic(inputTopic, partitions = 1)
+
+ val input = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", inputTopic)
+ .option("startingOffsets", "earliest")
+ .load()
+ val topic = newTopic()
+
+ var writer: StreamingQuery = null
+ var ex: Exception = null
+ try {
+ ex = intercept[StreamingQueryException] {
+ writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))()
+ testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+ eventually(timeout(streamingTimeout)) {
+ assert(writer.exception.isDefined)
+ }
+ throw writer.exception.get
+ }
+ } finally {
+ writer.stop()
+ }
+ assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted"))
+ }
+
+ test("streaming - exception on config serializer") {
+ val inputTopic = newTopic()
+ testUtils.createTopic(inputTopic, partitions = 1)
+ testUtils.sendMessages(inputTopic, Array("0"))
+
+ val input = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", inputTopic)
+ .load()
+ var writer: StreamingQuery = null
+ var ex: Exception = null
+ try {
+ ex = intercept[StreamingQueryException] {
+ writer = createKafkaWriter(
+ input.toDF(),
+ withOptions = Map("kafka.key.serializer" -> "foo"))()
+ writer.processAllAvailable()
+ }
+ assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
+ "kafka option 'key.serializer' is not supported"))
+ } finally {
+ writer.stop()
+ }
+
+ try {
+ ex = intercept[StreamingQueryException] {
+ writer = createKafkaWriter(
+ input.toDF(),
+ withOptions = Map("kafka.value.serializer" -> "foo"))()
+ writer.processAllAvailable()
+ }
+ assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
+ "kafka option 'value.serializer' is not supported"))
+ } finally {
+ writer.stop()
+ }
+ }
+
+ test("generic - write big data with small producer buffer") {
+ /* This test ensures that we understand the semantics of Kafka when
+ * it comes to blocking on a call to send when the send buffer is full.
+ * This test will configure the smallest possible producer buffer and
+ * indicate that we should block when it is full. Thus, no exception should
+ * be thrown in the case of a full buffer.
+ */
+ val topic = newTopic()
+ testUtils.createTopic(topic, 1)
+ val options = new java.util.HashMap[String, String]
+ options.put("bootstrap.servers", testUtils.brokerAddress)
+ options.put("buffer.memory", "16384") // min buffer size
+ options.put("block.on.buffer.full", "true")
+ options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName)
+ options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName)
+ val inputSchema = Seq(AttributeReference("value", BinaryType)())
+ val data = new Array[Byte](15000) // large value
+ val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema)
+ try {
+ val fieldTypes: Array[DataType] = Array(BinaryType)
+ val converter = UnsafeProjection.create(fieldTypes)
+ val row = new SpecificInternalRow(fieldTypes)
+ row.update(0, data)
+ val iter = Seq.fill(1000)(converter.apply(row)).iterator
+ iter.foreach(writeTask.write(_))
+ writeTask.commit()
+ } finally {
+ writeTask.close()
+ }
+ }
+
+ private def createKafkaReader(topic: String): DataFrame = {
+ spark.read
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("startingOffsets", "earliest")
+ .option("endingOffsets", "latest")
+ .option("subscribe", topic)
+ .load()
+ }
+
+ private def createKafkaWriter(
+ input: DataFrame,
+ withTopic: Option[String] = None,
+ withOutputMode: Option[OutputMode] = None,
+ withOptions: Map[String, String] = Map[String, String]())
+ (withSelectExpr: String*): StreamingQuery = {
+ var stream: DataStreamWriter[Row] = null
+ val checkpointDir = Utils.createTempDir()
+ var df = input.toDF()
+ if (withSelectExpr.length > 0) {
+ df = df.selectExpr(withSelectExpr: _*)
+ }
+ stream = df.writeStream
+ .format("kafka")
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ // We need to reduce blocking time to efficiently test non-existent partition behavior.
+ .option("kafka.max.block.ms", "1000")
+ .trigger(Trigger.Continuous(1000))
+ .queryName("kafkaStream")
+ withTopic.foreach(stream.option("topic", _))
+ withOutputMode.foreach(stream.outputMode(_))
+ withOptions.foreach(opt => stream.option(opt._1, opt._2))
+ stream.start()
+ }
+}
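
Stripped of the harness, the pattern these tests exercise is the ordinary Kafka-to-Kafka streaming write, with the only new piece being the continuous trigger. A condensed sketch (topics, servers, and the checkpoint path are placeholders):

    import org.apache.spark.sql.streaming.Trigger

    val query = spark.readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "host:9092")
      .option("subscribe", "in")
      .load()
      .selectExpr("CAST(key AS STRING) key", "CAST(value AS STRING) value")
      .writeStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "host:9092")
      .option("topic", "out")
      .option("checkpointLocation", "/tmp/kafka-cp")
      .trigger(Trigger.Continuous(1000))  // checkpoint epoch interval, as in the suite
      .start()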
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
new file mode 100644
index 0000000..b3dade4
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.util.Properties
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.scalatest.time.SpanSugar._
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.streaming.StreamExecution
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
+import org.apache.spark.sql.streaming.{StreamTest, Trigger}
+import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
+
+// Run tests in KafkaSourceSuiteBase in continuous execution mode.
+class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest
+
+class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
+ import testImplicits._
+
+ override val brokerProps = Map("auto.create.topics.enable" -> "false")
+
+ test("subscribing topic by pattern with topic deletions") {
+ val topicPrefix = newTopic()
+ val topic = topicPrefix + "-seems"
+ val topic2 = topicPrefix + "-bad"
+ testUtils.createTopic(topic, partitions = 5)
+ testUtils.sendMessages(topic, Array("-1"))
+ require(testUtils.getLatestOffsets(Set(topic)).size === 5)
+
+ val reader = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("kafka.metadata.max.age.ms", "1")
+ .option("subscribePattern", s"$topicPrefix-.*")
+ .option("failOnDataLoss", "false")
+
+ val kafka = reader.load()
+ .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+ .as[(String, String)]
+ val mapped = kafka.map(kv => kv._2.toInt + 1)
+
+ testStream(mapped)(
+ makeSureGetOffsetCalled,
+ AddKafkaData(Set(topic), 1, 2, 3),
+ CheckAnswer(2, 3, 4),
+ Execute { query =>
+ testUtils.deleteTopic(topic)
+ testUtils.createTopic(topic2, partitions = 5)
+ eventually(timeout(streamingTimeout)) {
+ assert(
+ query.lastExecution.logical.collectFirst {
+ case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
+ }.exists { r =>
+ // Ensure the new topic is present and the old topic is gone.
+ r.knownPartitions.exists(_.topic == topic2)
+ },
+ s"query never reconfigured to new topic $topic2")
+ }
+ },
+ AddKafkaData(Set(topic2), 4, 5, 6),
+ CheckAnswer(2, 3, 4, 5, 6, 7)
+ )
+ }
+}
+
+class KafkaContinuousSourceStressForDontFailOnDataLossSuite
+ extends KafkaSourceStressForDontFailOnDataLossSuite {
+ override protected def startStream(ds: Dataset[Int]) = {
+ ds.writeStream
+ .format("memory")
+ .queryName("memory")
+ .start()
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
new file mode 100644
index 0000000..e713e66
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.streaming.StreamExecution
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.sql.test.TestSparkSession
+
+// Trait to configure StreamTest for kafka continuous execution tests.
+trait KafkaContinuousTest extends KafkaSourceTest {
+ override val defaultTrigger = Trigger.Continuous(1000)
+ override val defaultUseV2Sink = true
+
+ // We need more than the default local[2] to be able to schedule all partitions simultaneously.
+ override protected def createSparkSession = new TestSparkSession(
+ new SparkContext(
+ "local[10]",
+ "continuous-stream-test-sql-context",
+ sparkConf.set("spark.sql.testkey", "true")))
+
+ // In addition to setting the partitions in Kafka, we have to wait until the query has
+ // reconfigured to the new count so the test framework can hook in properly.
+ override protected def setTopicPartitions(
+ topic: String, newCount: Int, query: StreamExecution) = {
+ testUtils.addPartitions(topic, newCount)
+ eventually(timeout(streamingTimeout)) {
+ assert(
+ query.lastExecution.logical.collectFirst {
+ case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
+ }.exists(_.knownPartitions.size == newCount),
+ s"query never reconfigured to $newCount partitions")
+ }
+ }
+
+ test("ensure continuous stream is being used") {
+ val query = spark.readStream
+ .format("rate")
+ .option("numPartitions", "1")
+ .option("rowsPerSecond", "1")
+ .load()
+
+ testStream(query)(
+ Execute(q => assert(q.isInstanceOf[ContinuousExecution]))
+ )
+ }
+}
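
The local[10] master above is not incidental: continuous execution holds one long-running task per input partition for the life of the query, so the scheduler (here, local threads) must supply at least one core per partition or the query never gets fully scheduled. A hypothetical session sized for a 5-partition topic:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      // At least one core per Kafka partition (5 here), plus headroom;
      // with too few cores a continuous query simply stalls.
      .master("local[6]")
      .appName("continuous-kafka")
      .getOrCreate()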
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
index 2034b9b..d66908f8 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
@@ -34,11 +34,14 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkContext
-import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
+import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter
import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.kafka010.KafkaSourceProvider._
-import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest}
+import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger}
import org.apache.spark.sql.streaming.util.StreamManualClock
import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
import org.apache.spark.util.Utils
@@ -49,9 +52,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
override val streamingTimeout = 30.seconds
+ protected val brokerProps = Map[String, Object]()
+
override def beforeAll(): Unit = {
super.beforeAll()
- testUtils = new KafkaTestUtils
+ testUtils = new KafkaTestUtils(brokerProps)
testUtils.setup()
}
@@ -59,18 +64,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
if (testUtils != null) {
testUtils.teardown()
testUtils = null
- super.afterAll()
}
+ super.afterAll()
}
protected def makeSureGetOffsetCalled = AssertOnQuery { q =>
// Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure
- // its "getOffset" is called before pushing any data. Otherwise, because of the race contion,
+ // its "getOffset" is called before pushing any data. Otherwise, because of the race condition,
// we don't know which data should be fetched when `startingOffsets` is latest.
- q.processAllAvailable()
+ q match {
+ case c: ContinuousExecution => c.awaitEpoch(0)
+ case m: MicroBatchExecution => m.processAllAvailable()
+ }
true
}
+ protected def setTopicPartitions(topic: String, newCount: Int, query: StreamExecution): Unit = {
+ testUtils.addPartitions(topic, newCount)
+ }
+
/**
* Add data to Kafka.
*
@@ -82,7 +94,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
message: String = "",
topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData {
- override def addData(query: Option[StreamExecution]): (Source, Offset) = {
+ override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
if (query.get.isActive) {
// Make sure no Spark job is running when deleting a topic
query.get.processAllAvailable()
@@ -97,16 +109,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2))
}
- // Read all topics again in case some topics are delete.
- val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys
require(
query.nonEmpty,
"Cannot add data when there is no query for finding the active kafka source")
val sources = query.get.logicalPlan.collect {
- case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] =>
- source.asInstanceOf[KafkaSource]
- }
+ case StreamingExecutionRelation(source: KafkaSource, _) => source
+ } ++ (query.get.lastExecution match {
+ case null => Seq()
+ case e => e.logical.collect {
+ case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
+ }
+ })
if (sources.isEmpty) {
throw new Exception(
"Could not find Kafka source in the StreamExecution logical plan to add data to")
@@ -137,14 +151,158 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
override def toString: String =
s"AddKafkaData(topics = $topics, data = $data, message = $message)"
}
-}
+ private val topicId = new AtomicInteger(0)
+ protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
+}
-class KafkaSourceSuite extends KafkaSourceTest {
+class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase {
import testImplicits._
- private val topicId = new AtomicInteger(0)
+ test("(de)serialization of initial offsets") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 5)
+
+ val reader = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+
+ testStream(reader.load)(
+ makeSureGetOffsetCalled,
+ StopStream,
+ StartStream(),
+ StopStream)
+ }
+
+ test("maxOffsetsPerTrigger") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 3)
+ testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0))
+ testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1))
+ testUtils.sendMessages(topic, Array("1"), Some(2))
+
+ val reader = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("kafka.metadata.max.age.ms", "1")
+ .option("maxOffsetsPerTrigger", 10)
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ val kafka = reader.load()
+ .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+ .as[(String, String)]
+ val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt)
+
+ val clock = new StreamManualClock
+
+ val waitUntilBatchProcessed = AssertOnQuery { q =>
+ eventually(Timeout(streamingTimeout)) {
+ if (!q.exception.isDefined) {
+ assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
+ }
+ }
+ if (q.exception.isDefined) {
+ throw q.exception.get
+ }
+ true
+ }
+
+ testStream(mapped)(
+ StartStream(ProcessingTime(100), clock),
+ waitUntilBatchProcessed,
+ // 1 from smallest, 1 from middle, 8 from biggest
+ CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107),
+ AdvanceManualClock(100),
+ waitUntilBatchProcessed,
+ // smallest now empty, 1 more from middle, 9 more from biggest
+ CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
+ 11, 108, 109, 110, 111, 112, 113, 114, 115, 116
+ ),
+ StopStream,
+ StartStream(ProcessingTime(100), clock),
+ waitUntilBatchProcessed,
+ // smallest now empty, 1 more from middle, 9 more from biggest
+ CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
+ 11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
+ 12, 117, 118, 119, 120, 121, 122, 123, 124, 125
+ ),
+ AdvanceManualClock(100),
+ waitUntilBatchProcessed,
+ // smallest now empty, 1 more from middle, 9 more from biggest
+ CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
+ 11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
+ 12, 117, 118, 119, 120, 121, 122, 123, 124, 125,
+ 13, 126, 127, 128, 129, 130, 131, 132, 133, 134
+ )
+ )
+ }
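
The per-partition counts asserted above follow from proportional rate limiting: with 101, 11, and 1 messages available and maxOffsetsPerTrigger = 10, each partition is allotted offsets roughly in proportion to its backlog. A rough model of the split, assuming floor-based proration with at least one offset for any non-empty partition (the exact KafkaSource policy may differ in corner cases):

    // available: per-partition backlog; limit: maxOffsetsPerTrigger
    def prorate(available: Seq[Long], limit: Long): Seq[Long] = {
      val total = available.sum.toDouble
      available.map { n =>
        if (n == 0L) 0L else math.max(1L, math.floor(limit * n / total).toLong)
      }
    }

    prorate(Seq(101L, 11L, 1L), 10)  // Seq(8, 1, 1): 8 from biggest, 1 each from the others
    prorate(Seq(93L, 10L, 0L), 10)   // Seq(9, 1, 0): matches the second batch above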
+
+ test("input row metrics") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 5)
+ testUtils.sendMessages(topic, Array("-1"))
+ require(testUtils.getLatestOffsets(Set(topic)).size === 5)
+
+ val kafka = spark
+ .readStream
+ .format("kafka")
+ .option("subscribe", topic)
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .load()
+ .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+ .as[(String, String)]
+
+ val mapped = kafka.map(kv => kv._2.toInt + 1)
+ testStream(mapped)(
+ StartStream(trigger = ProcessingTime(1)),
+ makeSureGetOffsetCalled,
+ AddKafkaData(Set(topic), 1, 2, 3),
+ CheckAnswer(2, 3, 4),
+ AssertOnQuery { query =>
+ val recordsRead = query.recentProgress.map(_.numInputRows).sum
+ recordsRead == 3
+ }
+ )
+ }
+
+ test("subscribing topic by pattern with topic deletions") {
+ val topicPrefix = newTopic()
+ val topic = topicPrefix + "-seems"
+ val topic2 = topicPrefix + "-bad"
+ testUtils.createTopic(topic, partitions = 5)
+ testUtils.sendMessages(topic, Array("-1"))
+ require(testUtils.getLatestOffsets(Set(topic)).size === 5)
+
+ val reader = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("kafka.metadata.max.age.ms", "1")
+ .option("subscribePattern", s"$topicPrefix-.*")
+ .option("failOnDataLoss", "false")
+
+ val kafka = reader.load()
+ .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+ .as[(String, String)]
+ val mapped = kafka.map(kv => kv._2.toInt + 1)
+
+ testStream(mapped)(
+ makeSureGetOffsetCalled,
+ AddKafkaData(Set(topic), 1, 2, 3),
+ CheckAnswer(2, 3, 4),
+ Assert {
+ testUtils.deleteTopic(topic)
+ testUtils.createTopic(topic2, partitions = 5)
+ true
+ },
+ AddKafkaData(Set(topic2), 4, 5, 6),
+ CheckAnswer(2, 3, 4, 5, 6, 7)
+ )
+ }
testWithUninterruptibleThread(
"deserialization of initial offset with Spark 2.1.0") {
@@ -237,86 +395,51 @@ class KafkaSourceSuite extends KafkaSourceTest {
}
}
- test("(de)serialization of initial offsets") {
- val topic = newTopic()
- testUtils.createTopic(topic, partitions = 64)
-
- val reader = spark
- .readStream
- .format("kafka")
- .option("kafka.bootstrap.servers", testUtils.brokerAddress)
- .option("subscribe", topic)
-
- testStream(reader.load)(
- makeSureGetOffsetCalled,
- StopStream,
- StartStream(),
- StopStream)
- }
-
- test("maxOffsetsPerTrigger") {
+ test("KafkaSource with watermark") {
+ val now = System.currentTimeMillis()
val topic = newTopic()
- testUtils.createTopic(topic, partitions = 3)
- testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0))
- testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1))
- testUtils.sendMessages(topic, Array("1"), Some(2))
+ testUtils.createTopic(topic, partitions = 1)
+ testUtils.sendMessages(topic, Array(1).map(_.toString))
- val reader = spark
+ val kafka = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
- .option("maxOffsetsPerTrigger", 10)
+ .option("startingOffsets", s"earliest")
.option("subscribe", topic)
- .option("startingOffsets", "earliest")
- val kafka = reader.load()
- .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
- .as[(String, String)]
- val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt)
-
- val clock = new StreamManualClock
+ .load()
- val waitUntilBatchProcessed = AssertOnQuery { q =>
- eventually(Timeout(streamingTimeout)) {
- if (!q.exception.isDefined) {
- assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
- }
- }
- if (q.exception.isDefined) {
- throw q.exception.get
- }
- true
- }
+ val windowedAggregation = kafka
+ .withWatermark("timestamp", "10 seconds")
+ .groupBy(window($"timestamp", "5 seconds") as 'window)
+ .agg(count("*") as 'count)
+ .select($"window".getField("start") as 'window, $"count")
- testStream(mapped)(
- StartStream(ProcessingTime(100), clock),
- waitUntilBatchProcessed,
- // 1 from smallest, 1 from middle, 8 from biggest
- CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107),
- AdvanceManualClock(100),
- waitUntilBatchProcessed,
- // smallest now empty, 1 more from middle, 9 more from biggest
- CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
- 11, 108, 109, 110, 111, 112, 113, 114, 115, 116
- ),
- StopStream,
- StartStream(ProcessingTime(100), clock),
- waitUntilBatchProcessed,
- // smallest now empty, 1 more from middle, 9 more from biggest
- CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
- 11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
- 12, 117, 118, 119, 120, 121, 122, 123, 124, 125
- ),
- AdvanceManualClock(100),
- waitUntilBatchProcessed,
- // smallest now empty, 1 more from middle, 9 more from biggest
- CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
- 11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
- 12, 117, 118, 119, 120, 121, 122, 123, 124, 125,
- 13, 126, 127, 128, 129, 130, 131, 132, 133, 134
- )
- )
+ val query = windowedAggregation
+ .writeStream
+ .format("memory")
+ .outputMode("complete")
+ .queryName("kafkaWatermark")
+ .start()
+ query.processAllAvailable()
+ val rows = spark.table("kafkaWatermark").collect()
+ assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
+ val row = rows(0)
+ // We cannot check the exact window start time as it depends on the time the messages were
+ // inserted by the producer. So here we just use a lower bound to make sure the internal
+ // conversion works.
+ assert(
+ row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000,
+ s"Unexpected results: $row")
+ assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row")
+ query.stop()
}
+}
+
+class KafkaSourceSuiteBase extends KafkaSourceTest {
+
+ import testImplicits._
test("cannot stop Kafka stream") {
val topic = newTopic()
@@ -328,7 +451,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("kafka.metadata.max.age.ms", "1")
- .option("subscribePattern", s"topic-.*")
+ .option("subscribePattern", s"$topic.*")
val kafka = reader.load()
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
@@ -422,65 +545,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
}
}
- test("subscribing topic by pattern with topic deletions") {
- val topicPrefix = newTopic()
- val topic = topicPrefix + "-seems"
- val topic2 = topicPrefix + "-bad"
- testUtils.createTopic(topic, partitions = 5)
- testUtils.sendMessages(topic, Array("-1"))
- require(testUtils.getLatestOffsets(Set(topic)).size === 5)
-
- val reader = spark
- .readStream
- .format("kafka")
- .option("kafka.bootstrap.servers", testUtils.brokerAddress)
- .option("kafka.metadata.max.age.ms", "1")
- .option("subscribePattern", s"$topicPrefix-.*")
- .option("failOnDataLoss", "false")
-
- val kafka = reader.load()
- .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
- .as[(String, String)]
- val mapped = kafka.map(kv => kv._2.toInt + 1)
-
- testStream(mapped)(
- makeSureGetOffsetCalled,
- AddKafkaData(Set(topic), 1, 2, 3),
- CheckAnswer(2, 3, 4),
- Assert {
- testUtils.deleteTopic(topic)
- testUtils.createTopic(topic2, partitions = 5)
- true
- },
- AddKafkaData(Set(topic2), 4, 5, 6),
- CheckAnswer(2, 3, 4, 5, 6, 7)
- )
- }
-
- test("starting offset is latest by default") {
- val topic = newTopic()
- testUtils.createTopic(topic, partitions = 5)
- testUtils.sendMessages(topic, Array("0"))
- require(testUtils.getLatestOffsets(Set(topic)).size === 5)
-
- val reader = spark
- .readStream
- .format("kafka")
- .option("kafka.bootstrap.servers", testUtils.brokerAddress)
- .option("subscribe", topic)
-
- val kafka = reader.load()
- .selectExpr("CAST(value AS STRING)")
- .as[String]
- val mapped = kafka.map(_.toInt)
-
- testStream(mapped)(
- makeSureGetOffsetCalled,
- AddKafkaData(Set(topic), 1, 2, 3),
- CheckAnswer(1, 2, 3) // should not have 0
- )
- }
-
test("bad source options") {
def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = {
val ex = intercept[IllegalArgumentException] {
@@ -540,34 +604,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
testUnsupportedConfig("kafka.auto.offset.reset", "latest")
}
- test("input row metrics") {
- val topic = newTopic()
- testUtils.createTopic(topic, partitions = 5)
- testUtils.sendMessages(topic, Array("-1"))
- require(testUtils.getLatestOffsets(Set(topic)).size === 5)
-
- val kafka = spark
- .readStream
- .format("kafka")
- .option("subscribe", topic)
- .option("kafka.bootstrap.servers", testUtils.brokerAddress)
- .load()
- .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
- .as[(String, String)]
-
- val mapped = kafka.map(kv => kv._2.toInt + 1)
- testStream(mapped)(
- StartStream(trigger = ProcessingTime(1)),
- makeSureGetOffsetCalled,
- AddKafkaData(Set(topic), 1, 2, 3),
- CheckAnswer(2, 3, 4),
- AssertOnQuery { query =>
- val recordsRead = query.recentProgress.map(_.numInputRows).sum
- recordsRead == 3
- }
- )
- }
-
test("delete a topic when a Spark job is running") {
KafkaSourceSuite.collectedData.clear()
@@ -629,8 +665,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
}
}
- private def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
-
private def assignString(topic: String, partitions: Iterable[Int]): String = {
JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p)))
}
@@ -676,6 +710,10 @@ class KafkaSourceSuite extends KafkaSourceTest {
testStream(mapped)(
makeSureGetOffsetCalled,
+ Execute { q =>
+ // wait to reach the last offset in every partition
+ q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L)))
+ },
CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22),
StopStream,
StartStream(),
@@ -706,6 +744,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
.format("memory")
.outputMode("append")
.queryName("kafkaColumnTypes")
+ .trigger(defaultTrigger)
.start()
query.processAllAvailable()
val rows = spark.table("kafkaColumnTypes").collect()
@@ -723,47 +762,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
query.stop()
}
- test("KafkaSource with watermark") {
- val now = System.currentTimeMillis()
- val topic = newTopic()
- testUtils.createTopic(newTopic(), partitions = 1)
- testUtils.sendMessages(topic, Array(1).map(_.toString))
-
- val kafka = spark
- .readStream
- .format("kafka")
- .option("kafka.bootstrap.servers", testUtils.brokerAddress)
- .option("kafka.metadata.max.age.ms", "1")
- .option("startingOffsets", s"earliest")
- .option("subscribe", topic)
- .load()
-
- val windowedAggregation = kafka
- .withWatermark("timestamp", "10 seconds")
- .groupBy(window($"timestamp", "5 seconds") as 'window)
- .agg(count("*") as 'count)
- .select($"window".getField("start") as 'window, $"count")
-
- val query = windowedAggregation
- .writeStream
- .format("memory")
- .outputMode("complete")
- .queryName("kafkaWatermark")
- .start()
- query.processAllAvailable()
- val rows = spark.table("kafkaWatermark").collect()
- assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
- val row = rows(0)
- // We cannot check the exact window start time as it depands on the time that messages were
- // inserted by the producer. So here we just use a low bound to make sure the internal
- // conversion works.
- assert(
- row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000,
- s"Unexpected results: $row")
- assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row")
- query.stop()
- }
-
private def testFromLatestOffsets(
topic: String,
addPartitions: Boolean,
@@ -800,9 +798,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
AddKafkaData(Set(topic), 7, 8),
CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9),
AssertOnQuery("Add partitions") { query: StreamExecution =>
- if (addPartitions) {
- testUtils.addPartitions(topic, 10)
- }
+ if (addPartitions) setTopicPartitions(topic, 10, query)
true
},
AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16),
@@ -843,9 +839,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
StartStream(),
CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9),
AssertOnQuery("Add partitions") { query: StreamExecution =>
- if (addPartitions) {
- testUtils.addPartitions(topic, 10)
- }
+ if (addPartitions) setTopicPartitions(topic, 10, query)
true
},
AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16),
@@ -977,20 +971,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
}
}
- test("stress test for failOnDataLoss=false") {
- val reader = spark
- .readStream
- .format("kafka")
- .option("kafka.bootstrap.servers", testUtils.brokerAddress)
- .option("kafka.metadata.max.age.ms", "1")
- .option("subscribePattern", "failOnDataLoss.*")
- .option("startingOffsets", "earliest")
- .option("failOnDataLoss", "false")
- .option("fetchOffset.retryIntervalMs", "3000")
- val kafka = reader.load()
- .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
- .as[(String, String)]
- val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] {
+ protected def startStream(ds: Dataset[Int]) = {
+ ds.writeStream.foreach(new ForeachWriter[Int] {
override def open(partitionId: Long, version: Long): Boolean = {
true
@@ -1004,6 +986,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
override def close(errorOrNull: Throwable): Unit = {
}
}).start()
+ }
+
+ test("stress test for failOnDataLoss=false") {
+ val reader = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("kafka.metadata.max.age.ms", "1")
+ .option("subscribePattern", "failOnDataLoss.*")
+ .option("startingOffsets", "earliest")
+ .option("failOnDataLoss", "false")
+ .option("fetchOffset.retryIntervalMs", "3000")
+ val kafka = reader.load()
+ .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+ .as[(String, String)]
+ val query = startStream(kafka.map(kv => kv._2.toInt))
val testTime = 1.minutes
val startTime = System.currentTimeMillis()
http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index e8d683a..b714a46 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -191,6 +191,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
ds = ds.asInstanceOf[DataSourceV2],
conf = sparkSession.sessionState.conf)).asJava)
+ // Streaming also uses the data source V2 API. So it may be that the data source implements
+ // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading
+ // the dataframe as a v1 source.
val reader = (ds, userSpecifiedSchema) match {
case (ds: ReadSupportWithSchema, Some(schema)) =>
ds.createReader(schema, options)
@@ -208,23 +211,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
}
reader
- case _ =>
- throw new AnalysisException(s"$cls does not support data reading.")
+ case _ => null // fall back to v1
}
- Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
+ if (reader == null) {
+ loadV1Source(paths: _*)
+ } else {
+ Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
+ }
} else {
- // Code path for data source v1.
- sparkSession.baseRelationToDataFrame(
- DataSource.apply(
- sparkSession,
- paths = paths,
- userSpecifiedSchema = userSpecifiedSchema,
- className = source,
- options = extraOptions.toMap).resolveRelation())
+ loadV1Source(paths: _*)
}
}
+ private def loadV1Source(paths: String*) = {
+ // Code path for data source v1.
+ sparkSession.baseRelationToDataFrame(
+ DataSource.apply(
+ sparkSession,
+ paths = paths,
+ userSpecifiedSchema = userSpecifiedSchema,
+ className = source,
+ options = extraOptions.toMap).resolveRelation())
+ }
+
/**
* Construct a `DataFrame` representing the database table accessible via JDBC URL
* url named table and connection properties.
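
The fallback this comment describes is now observable with the Kafka source itself: KafkaSourceProvider registers as a DataSourceV2 (for continuous streaming) but exposes no v2 batch reader, so a plain batch read drops into loadV1Source and behaves exactly as before. Sketched at the user level (the bootstrap address is a placeholder):

    val df = spark.read
      .format("kafka")
      .option("kafka.bootstrap.servers", "host:9092")
      .option("subscribe", "events")
      .option("startingOffsets", "earliest")
      .option("endingOffsets", "latest")
      .load()  // no v2 batch reader -> reader == null -> loadV1Source(paths: _*)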