You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2018/01/11 18:51:01 UTC

[1/2] spark git commit: [SPARK-22908] Add kafka source and sink for continuous processing.

Repository: spark
Updated Branches:
  refs/heads/master 0b2eefb67 -> 6f7aaed80


http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 3304f36..97f12ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
             }
           }
 
-        case _ => throw new AnalysisException(s"$cls does not support data writing.")
+        // Streaming also uses the data source V2 API. So it may be that the data source implements
+        // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving
+        // as though it's a V1 source.
+        case _ => saveToV1Source()
       }
     } else {
-      // Code path for data source v1.
-      runCommand(df.sparkSession, "save") {
-        DataSource(
-          sparkSession = df.sparkSession,
-          className = source,
-          partitionColumns = partitioningColumns.getOrElse(Nil),
-          options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
-      }
+      saveToV1Source()
+    }
+  }
+
+  private def saveToV1Source(): Unit = {
+    // Code path for data source v1.
+    runCommand(df.sparkSession, "save") {
+      DataSource(
+        sparkSession = df.sparkSession,
+        className = source,
+        partitionColumns = partitioningColumns.getOrElse(Nil),
+        options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
index f0bdf84..a4a857f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
@@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan)
         (index, message: WriterCommitMessage) => messages(index) = message
       )
 
-      logInfo(s"Data source writer $writer is committing.")
-      writer.commit(messages)
-      logInfo(s"Data source writer $writer committed.")
+      if (!writer.isInstanceOf[ContinuousWriter]) {
+        logInfo(s"Data source writer $writer is committing.")
+        writer.commit(messages)
+        logInfo(s"Data source writer $writer committed.")
+      }
     } catch {
       case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] =>
         // Interruption is how continuous queries are ended, so accept and ignore the exception.

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 24a8b00..cf27e1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -142,7 +142,8 @@ abstract class StreamExecution(
 
   override val id: UUID = UUID.fromString(streamMetadata.id)
 
-  override val runId: UUID = UUID.randomUUID
+  override def runId: UUID = currentRunId
+  protected var currentRunId = UUID.randomUUID
 
   /**
    * Pretty identified string of printing in logs. Format is
@@ -418,11 +419,17 @@ abstract class StreamExecution(
    * Blocks the current thread until processing for data from the given `source` has reached at
    * least the given `Offset`. This method is intended for use primarily when writing tests.
    */
-  private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = {
+  private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = {
     assertAwaitThread()
     def notDone = {
       val localCommittedOffsets = committedOffsets
-      !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset
+      if (sources == null) {
+        // sources might not be initialized yet
+        false
+      } else {
+        val source = sources(sourceIndex)
+        !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset
+      }
     }
 
     while (notDone) {
@@ -436,7 +443,7 @@ abstract class StreamExecution(
         awaitProgressLock.unlock()
       }
     }
-    logDebug(s"Unblocked at $newOffset for $source")
+    logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}")
   }
 
   /** A flag to indicate that a batch has completed with no new data available. */

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
index d79e4bd..e700aa4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
@@ -77,7 +77,6 @@ class ContinuousDataSourceRDD(
     dataReaderThread.start()
 
     context.addTaskCompletionListener(_ => {
-      reader.close()
       dataReaderThread.interrupt()
       epochPollExecutor.shutdown()
     })
@@ -201,6 +200,8 @@ class DataReaderThread(
         failedFlag.set(true)
         // Don't rethrow the exception in this thread. It's not needed, and the default Spark
         // exception handler will kill the executor.
+    } finally {
+      reader.close()
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index 9657b5e..667410e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.execution.streaming.continuous
 
+import java.util.UUID
 import java.util.concurrent.TimeUnit
+import java.util.function.UnaryOperator
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}
@@ -52,7 +54,7 @@ class ContinuousExecution(
     sparkSession, name, checkpointRoot, analyzedPlan, sink,
     trigger, triggerClock, outputMode, deleteCheckpointOnStop) {
 
-  @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty
+  @volatile protected var continuousSources: Seq[ContinuousReader] = _
   override protected def sources: Seq[BaseStreamingSource] = continuousSources
 
   override lazy val logicalPlan: LogicalPlan = {
@@ -78,15 +80,17 @@ class ContinuousExecution(
   }
 
   override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {
-    do {
-      try {
-        runContinuous(sparkSessionForStream)
-      } catch {
-        case _: InterruptedException if state.get().equals(RECONFIGURING) =>
-          // swallow exception and run again
-          state.set(ACTIVE)
+    val stateUpdate = new UnaryOperator[State] {
+      override def apply(s: State) = s match {
+        // If we ended the query to reconfigure, reset the state to active.
+        case RECONFIGURING => ACTIVE
+        case _ => s
       }
-    } while (state.get() == ACTIVE)
+    }
+
+    do {
+      runContinuous(sparkSessionForStream)
+    } while (state.updateAndGet(stateUpdate) == ACTIVE)
   }
 
   /**
@@ -120,12 +124,16 @@ class ContinuousExecution(
         }
         committedOffsets = nextOffsets.toStreamProgress(sources)
 
-        // Forcibly align commit and offset logs by slicing off any spurious offset logs from
-        // a previous run. We can't allow commits to an epoch that a previous run reached but
-        // this run has not.
-        offsetLog.purgeAfter(latestEpochId)
+        // Get to an epoch ID that has definitely never been sent to a sink before. Since sink
+        // commit happens between offset log write and commit log write, this means an epoch ID
+        // which is not in the offset log.
+        val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse {
+          throw new IllegalStateException(
+            s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" +
+              s"an element.")
+        }
+        currentBatchId = latestOffsetEpoch + 1
 
-        currentBatchId = latestEpochId + 1
         logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets")
         nextOffsets
       case None =>
@@ -141,6 +149,7 @@ class ContinuousExecution(
    * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with.
    */
   private def runContinuous(sparkSessionForQuery: SparkSession): Unit = {
+    currentRunId = UUID.randomUUID
     // A list of attributes that will need to be updated.
     val replacements = new ArrayBuffer[(Attribute, Attribute)]
     // Translate from continuous relation to the underlying data source.
@@ -225,13 +234,11 @@ class ContinuousExecution(
           triggerExecutor.execute(() => {
             startTrigger()
 
-            if (reader.needsReconfiguration()) {
-              state.set(RECONFIGURING)
+            if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) {
               stopSources()
               if (queryExecutionThread.isAlive) {
                 sparkSession.sparkContext.cancelJobGroup(runId.toString)
                 queryExecutionThread.interrupt()
-                // No need to join - this thread is about to end anyway.
               }
               false
             } else if (isActive) {
@@ -259,6 +266,7 @@ class ContinuousExecution(
           sparkSessionForQuery, lastExecution)(lastExecution.toRdd)
       }
     } finally {
+      epochEndpoint.askSync[Unit](StopContinuousExecutionWrites)
       SparkEnv.get.rpcEnv.stop(epochEndpoint)
 
       epochUpdateThread.interrupt()
@@ -273,17 +281,22 @@ class ContinuousExecution(
       epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = {
     assert(continuousSources.length == 1, "only one continuous source supported currently")
 
-    if (partitionOffsets.contains(null)) {
-      // If any offset is null, that means the corresponding partition hasn't seen any data yet, so
-      // there's nothing meaningful to add to the offset log.
-    }
     val globalOffset = reader.mergeOffsets(partitionOffsets.toArray)
-    synchronized {
-      if (queryExecutionThread.isAlive) {
-        offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
-      } else {
-        return
-      }
+    val oldOffset = synchronized {
+      offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
+      offsetLog.get(epoch - 1)
+    }
+
+    // If offset hasn't changed since last epoch, there's been no new data.
+    if (oldOffset.contains(OffsetSeq.fill(globalOffset))) {
+      noNewData = true
+    }
+
+    awaitProgressLock.lock()
+    try {
+      awaitProgressLockCondition.signalAll()
+    } finally {
+      awaitProgressLock.unlock()
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
index 98017c3..40dcbec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
@@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable
  */
 private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage
 
+/**
+ * The RpcEndpoint stop() will wait to clear out the message queue before terminating the
+ * object. This can lead to a race condition where the query restarts at epoch n, a new
+ * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1.
+ * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous
+ * message to stop any writes to the ContinuousExecution object.
+ */
+private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage
+
 // Init messages
 /**
  * Set the reader and writer partition counts. Tasks may not be started until the coordinator
@@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator(
     override val rpcEnv: RpcEnv)
   extends ThreadSafeRpcEndpoint with Logging {
 
+  private var queryWritesStopped: Boolean = false
+
   private var numReaderPartitions: Int = _
   private var numWriterPartitions: Int = _
 
@@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator(
         partitionCommits.remove(k)
       }
       for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) {
-        partitionCommits.remove(k)
+        partitionOffsets.remove(k)
       }
     }
   }
 
   override def receive: PartialFunction[Any, Unit] = {
+    // If we just drop these messages, we won't do any writes to the query. The lame duck tasks
+    // won't shed errors or anything.
+    case _ if queryWritesStopped => ()
+
     case CommitPartitionEpoch(partitionId, epoch, message) =>
       logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message")
       if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
@@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator(
     case SetWriterPartitions(numPartitions) =>
       numWriterPartitions = numPartitions
       context.reply(())
+
+    case StopContinuousExecutionWrites =>
+      queryWritesStopped = true
+      context.reply(())
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index db588ae..b5b4a05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
 import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2}
+import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport
 
 /**
  * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
         useTempCheckpointLocation = true,
         trigger = trigger)
     } else {
-      val dataSource =
-        DataSource(
-          df.sparkSession,
-          className = source,
-          options = extraOptions.toMap,
-          partitionColumns = normalizedParCols.getOrElse(Nil))
+      val sink = trigger match {
+        case _: ContinuousTrigger =>
+          val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
+          ds.newInstance() match {
+            case w: ContinuousWriteSupport => w
+            case _ => throw new AnalysisException(
+              s"Data source $source does not support continuous writing")
+          }
+        case _ =>
+          val ds = DataSource(
+            df.sparkSession,
+            className = source,
+            options = extraOptions.toMap,
+            partitionColumns = normalizedParCols.getOrElse(Nil))
+          ds.createSink(outputMode)
+      }
+
       df.sparkSession.sessionState.streamingQueryManager.startQuery(
         extraOptions.get("queryName"),
         extraOptions.get("checkpointLocation"),
         df,
         extraOptions.toMap,
-        dataSource.createSink(outputMode),
+        sink,
         outputMode,
         useTempCheckpointLocation = source == "console",
         recoverFromCheckpointLocation = true,

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index d46461f..0762895 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row}
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch}
+import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch}
 import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.StreamingQueryListener._
@@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     StateStore.stop() // stop the state store maintenance thread and unload store providers
   }
 
+  protected val defaultTrigger = Trigger.ProcessingTime(0)
+  protected val defaultUseV2Sink = false
+
   /** How long to wait for an active stream to catch up when checking a result. */
   val streamingTimeout = 10.seconds
 
@@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
 
   /** Starts the stream, resuming if data has already been processed. It must not be running. */
   case class StartStream(
-      trigger: Trigger = Trigger.ProcessingTime(0),
+      trigger: Trigger = defaultTrigger,
       triggerClock: Clock = new SystemClock,
       additionalConfs: Map[String, String] = Map.empty,
       checkpointLocation: String = null)
@@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
   def testStream(
       _stream: Dataset[_],
       outputMode: OutputMode = OutputMode.Append,
-      useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized {
+      useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized {
     import org.apache.spark.sql.streaming.util.StreamManualClock
 
     // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently
@@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
 
     def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = {
       verify(currentStream != null, "stream not running")
-      // Get the map of source index to the current source objects
-      val indexToSource = currentStream
-        .logicalPlan
-        .collect { case StreamingExecutionRelation(s, _) => s }
-        .zipWithIndex
-        .map(_.swap)
-        .toMap
 
       // Block until all data added has been processed for all the source
       awaiting.foreach { case (sourceIndex, offset) =>
         failAfter(streamingTimeout) {
-          currentStream.awaitOffset(indexToSource(sourceIndex), offset)
+          currentStream.awaitOffset(sourceIndex, offset)
         }
       }
 
@@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
             // after starting the query.
             try {
               currentStream.awaitInitialization(streamingTimeout.toMillis)
+              currentStream match {
+                case s: ContinuousExecution => eventually("IncrementalExecution was not created") {
+                    s.lastExecution.executedPlan // will fail if lastExecution is null
+                  }
+                case _ =>
+              }
             } catch {
               case _: StreamingQueryException =>
                 // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well.
@@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
 
               def findSourceIndex(plan: LogicalPlan): Option[Int] = {
                 plan
-                  .collect { case StreamingExecutionRelation(s, _) => s }
+                  .collect {
+                    case StreamingExecutionRelation(s, _) => s
+                    case DataSourceV2Relation(_, r) => r
+                  }
                   .zipWithIndex
                   .find(_._1 == source)
                   .map(_._2)
@@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
                   findSourceIndex(query.logicalPlan)
                 }.orElse {
                   findSourceIndex(stream.logicalPlan)
+                }.orElse {
+                  queryToUse.flatMap { q =>
+                    findSourceIndex(q.lastExecution.logical)
+                  }
                 }.getOrElse {
                   throw new IllegalArgumentException(
-                    "Could find index of the source to which data was added")
+                    "Could not find index of the source to which data was added")
                 }
 
               // Store the expected offset of added data to wait for it later


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[2/2] spark git commit: [SPARK-22908] Add kafka source and sink for continuous processing.

Posted by td...@apache.org.
[SPARK-22908] Add kafka source and sink for continuous processing.

## What changes were proposed in this pull request?

Add a Kafka source and sink for continuous processing. This involves two small changes to the execution engine:

* Move the data reader's close() call into the data reader thread itself, to avoid thread-safety issues.
* Fix the semantics of the RECONFIGURING StreamExecution state: state updates are now atomic, so the engine no longer has to catch and swallow an InterruptedException when reconfiguring.

## How was this patch tested?

New unit tests (KafkaContinuousSinkSuite, KafkaContinuousSourceSuite).

Author: Jose Torres <jo...@databricks.com>

Closes #20096 from jose-torres/continuous-kafka.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6f7aaed8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6f7aaed8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6f7aaed8

Branch: refs/heads/master
Commit: 6f7aaed805070d29dcba32e04ca7a1f581fa54b9
Parents: 0b2eefb
Author: Jose Torres <jo...@databricks.com>
Authored: Thu Jan 11 10:52:12 2018 -0800
Committer: Tathagata Das <ta...@gmail.com>
Committed: Thu Jan 11 10:52:12 2018 -0800

----------------------------------------------------------------------
 .../sql/kafka010/KafkaContinuousReader.scala    | 232 +++++++++
 .../sql/kafka010/KafkaContinuousWriter.scala    | 119 +++++
 .../spark/sql/kafka010/KafkaOffsetReader.scala  |  21 +-
 .../apache/spark/sql/kafka010/KafkaSource.scala |  17 +-
 .../spark/sql/kafka010/KafkaSourceOffset.scala  |   7 +-
 .../sql/kafka010/KafkaSourceProvider.scala      | 105 +++-
 .../spark/sql/kafka010/KafkaWriteTask.scala     |  71 +--
 .../apache/spark/sql/kafka010/KafkaWriter.scala |   5 +-
 .../sql/kafka010/KafkaContinuousSinkSuite.scala | 474 +++++++++++++++++++
 .../kafka010/KafkaContinuousSourceSuite.scala   |  96 ++++
 .../sql/kafka010/KafkaContinuousTest.scala      |  64 +++
 .../spark/sql/kafka010/KafkaSourceSuite.scala   | 470 +++++++++---------
 .../org/apache/spark/sql/DataFrameReader.scala  |  32 +-
 .../org/apache/spark/sql/DataFrameWriter.scala  |  25 +-
 .../datasources/v2/WriteToDataSourceV2.scala    |   8 +-
 .../execution/streaming/StreamExecution.scala   |  15 +-
 .../ContinuousDataSourceRDDIter.scala           |   3 +-
 .../continuous/ContinuousExecution.scala        |  67 +--
 .../streaming/continuous/EpochCoordinator.scala |  21 +-
 .../spark/sql/streaming/DataStreamWriter.scala  |  26 +-
 .../apache/spark/sql/streaming/StreamTest.scala |  36 +-
 21 files changed, 1531 insertions(+), 383 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
new file mode 100644
index 0000000..9283795
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousReader.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.{util => ju}
+
+import org.apache.kafka.clients.consumer.ConsumerRecord
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.kafka010.KafkaSource.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
+import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.sources.v2.streaming.reader.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A [[ContinuousReader]] for data from kafka.
+ *
+ * @param offsetReader  a reader used to get kafka offsets. Note that the actual data will be
+ *                      read by per-task consumers generated later.
+ * @param kafkaParams   String params for per-task Kafka consumers.
+ * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceV2Options]] params which
+ *                      are not Kafka consumer params.
+ * @param metadataPath Path to a directory this reader can use for writing metadata.
+ * @param initialOffsets The Kafka offsets to start reading data at.
+ * @param failOnDataLoss Flag indicating whether reading should fail in data loss
+ *                       scenarios, where some offsets after the specified initial ones can't be
+ *                       properly read.
+ */
+class KafkaContinuousReader(
+    offsetReader: KafkaOffsetReader,
+    kafkaParams: ju.Map[String, Object],
+    sourceOptions: Map[String, String],
+    metadataPath: String,
+    initialOffsets: KafkaOffsetRangeLimit,
+    failOnDataLoss: Boolean)
+  extends ContinuousReader with SupportsScanUnsafeRow with Logging {
+
+  private lazy val session = SparkSession.getActiveSession.get
+  private lazy val sc = session.sparkContext
+
+  // Initialized when creating read tasks. If this diverges from the partitions at the latest
+  // offsets, we need to reconfigure.
+  // Exposed outside this object only for unit tests.
+  private[sql] var knownPartitions: Set[TopicPartition] = _
+
+  override def readSchema: StructType = KafkaOffsetReader.kafkaSchema
+
+  private var offset: Offset = _
+  override def setOffset(start: ju.Optional[Offset]): Unit = {
+    offset = start.orElse {
+      val offsets = initialOffsets match {
+        case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets())
+        case LatestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchLatestOffsets())
+        case SpecificOffsetRangeLimit(p) => offsetReader.fetchSpecificOffsets(p, reportDataLoss)
+      }
+      logInfo(s"Initial offsets: $offsets")
+      offsets
+    }
+  }
+
+  override def getStartOffset(): Offset = offset
+
+  override def deserializeOffset(json: String): Offset = {
+    KafkaSourceOffset(JsonUtils.partitionOffsets(json))
+  }
+
+  override def createUnsafeRowReadTasks(): ju.List[ReadTask[UnsafeRow]] = {
+    import scala.collection.JavaConverters._
+
+    val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
+
+    val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet
+    val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet)
+    val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq)
+
+    val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet)
+    if (deletedPartitions.nonEmpty) {
+      reportDataLoss(s"Some partitions were deleted: $deletedPartitions")
+    }
+
+    val startOffsets = newPartitionOffsets ++
+      oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_))
+    knownPartitions = startOffsets.keySet
+
+    startOffsets.toSeq.map {
+      case (topicPartition, start) =>
+        KafkaContinuousReadTask(
+          topicPartition, start, kafkaParams, failOnDataLoss)
+          .asInstanceOf[ReadTask[UnsafeRow]]
+    }.asJava
+  }
+
+  /** Stop this source and free any resources it has allocated. */
+  def stop(): Unit = synchronized {
+    offsetReader.close()
+  }
+
+  override def commit(end: Offset): Unit = {}
+
+  override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
+    val mergedMap = offsets.map {
+      case KafkaSourcePartitionOffset(p, o) => Map(p -> o)
+    }.reduce(_ ++ _)
+    KafkaSourceOffset(mergedMap)
+  }
+
+  override def needsReconfiguration(): Boolean = {
+    knownPartitions != null && offsetReader.fetchLatestOffsets().keySet != knownPartitions
+  }
+
+  override def toString(): String = s"KafkaSource[$offsetReader]"
+
+  /**
+   * If `failOnDataLoss` is true, this method will throw an `IllegalStateException`.
+   * Otherwise, just log a warning.
+   */
+  private def reportDataLoss(message: String): Unit = {
+    if (failOnDataLoss) {
+      throw new IllegalStateException(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE")
+    } else {
+      logWarning(message + s". $INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE")
+    }
+  }
+}
+
+/**
+ * A read task for continuous Kafka processing. This will be serialized and transformed into a
+ * full reader on executors.
+ *
+ * @param topicPartition The (topic, partition) pair this task is responsible for.
+ * @param startOffset The offset to start reading from within the partition.
+ * @param kafkaParams Kafka consumer params to use.
+ * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
+ *                       are skipped.
+ */
+case class KafkaContinuousReadTask(
+    topicPartition: TopicPartition,
+    startOffset: Long,
+    kafkaParams: ju.Map[String, Object],
+    failOnDataLoss: Boolean) extends ReadTask[UnsafeRow] {
+  override def createDataReader(): KafkaContinuousDataReader = {
+    new KafkaContinuousDataReader(topicPartition, startOffset, kafkaParams, failOnDataLoss)
+  }
+}
+
+/**
+ * A per-task data reader for continuous Kafka processing.
+ *
+ * @param topicPartition The (topic, partition) pair this data reader is responsible for.
+ * @param startOffset The offset to start reading from within the partition.
+ * @param kafkaParams Kafka consumer params to use.
+ * @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
+ *                       are skipped.
+ */
+class KafkaContinuousDataReader(
+    topicPartition: TopicPartition,
+    startOffset: Long,
+    kafkaParams: ju.Map[String, Object],
+    failOnDataLoss: Boolean) extends ContinuousDataReader[UnsafeRow] {
+  private val topic = topicPartition.topic
+  private val kafkaPartition = topicPartition.partition
+  private val consumer = CachedKafkaConsumer.createUncached(topic, kafkaPartition, kafkaParams)
+
+  private val sharedRow = new UnsafeRow(7)
+  private val bufferHolder = new BufferHolder(sharedRow)
+  private val rowWriter = new UnsafeRowWriter(bufferHolder, 7)
+
+  private var nextKafkaOffset = startOffset
+  private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _
+
+  override def next(): Boolean = {
+    var r: ConsumerRecord[Array[Byte], Array[Byte]] = null
+    while (r == null) {
+      r = consumer.get(
+        nextKafkaOffset,
+        untilOffset = Long.MaxValue,
+        pollTimeoutMs = Long.MaxValue,
+        failOnDataLoss)
+    }
+    nextKafkaOffset = r.offset + 1
+    currentRecord = r
+    true
+  }
+
+  override def get(): UnsafeRow = {
+    bufferHolder.reset()
+
+    if (currentRecord.key == null) {
+      rowWriter.setNullAt(0)
+    } else {
+      rowWriter.write(0, currentRecord.key)
+    }
+    rowWriter.write(1, currentRecord.value)
+    rowWriter.write(2, UTF8String.fromString(currentRecord.topic))
+    rowWriter.write(3, currentRecord.partition)
+    rowWriter.write(4, currentRecord.offset)
+    rowWriter.write(5,
+      DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(currentRecord.timestamp)))
+    rowWriter.write(6, currentRecord.timestampType.id)
+    sharedRow.setTotalSize(bufferHolder.totalSize)
+    sharedRow
+  }
+
+  override def getOffset(): KafkaSourcePartitionOffset = {
+    KafkaSourcePartitionOffset(topicPartition, nextKafkaOffset)
+  }
+
+  override def close(): Unit = {
+    consumer.close()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala
new file mode 100644
index 0000000..9843f46
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousWriter.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import org.apache.kafka.clients.producer.{Callback, ProducerRecord, RecordMetadata}
+import scala.collection.JavaConverters._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
+import org.apache.spark.sql.kafka010.KafkaSourceProvider.{kafkaParamsForProducer, TOPIC_OPTION_KEY}
+import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
+import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
+import org.apache.spark.sql.sources.v2.writer._
+import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.types.{BinaryType, StringType, StructType}
+
+/**
+ * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we
+ * don't need to really send one.
+ */
+case object KafkaWriterCommitMessage extends WriterCommitMessage
+
+/**
+ * A [[ContinuousWriter]] for Kafka writing. Responsible for generating the writer factory.
+ * @param topic The topic this writer is responsible for. If None, topic will be inferred from
+ *              a `topic` field in the incoming data.
+ * @param producerParams Parameters for Kafka producers in each task.
+ * @param schema The schema of the input data.
+ */
+class KafkaContinuousWriter(
+    topic: Option[String], producerParams: Map[String, String], schema: StructType)
+  extends ContinuousWriter with SupportsWriteInternalRow {
+
+  validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)
+
+  override def createInternalRowWriterFactory(): KafkaContinuousWriterFactory =
+    KafkaContinuousWriterFactory(topic, producerParams, schema)
+
+  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+  override def abort(messages: Array[WriterCommitMessage]): Unit = {}
+}
+
+/**
+ * A [[DataWriterFactory]] for Kafka writing. Will be serialized and sent to executors to generate
+ * the per-task data writers.
+ * @param topic The topic that should be written to. If None, topic will be inferred from
+ *              a `topic` field in the incoming data.
+ * @param producerParams Parameters for Kafka producers in each task.
+ * @param schema The schema of the input data.
+ */
+case class KafkaContinuousWriterFactory(
+    topic: Option[String], producerParams: Map[String, String], schema: StructType)
+  extends DataWriterFactory[InternalRow] {
+
+  override def createDataWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = {
+    new KafkaContinuousDataWriter(topic, producerParams, schema.toAttributes)
+  }
+}
+
+/**
+ * A [[DataWriter]] for Kafka writing. One data writer will be created in each partition to
+ * process incoming rows.
+ *
+ * @param targetTopic The topic that this data writer is targeting. If None, topic will be inferred
+ *                    from a `topic` field in the incoming data.
+ * @param producerParams Parameters to use for the Kafka producer.
+ * @param inputSchema The attributes in the input data.
+ */
+class KafkaContinuousDataWriter(
+    targetTopic: Option[String], producerParams: Map[String, String], inputSchema: Seq[Attribute])
+  extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] {
+  import scala.collection.JavaConverters._
+
+  private lazy val producer = CachedKafkaProducer.getOrCreate(
+    new java.util.HashMap[String, Object](producerParams.asJava))
+
+  def write(row: InternalRow): Unit = {
+    checkForErrors()
+    sendRow(row, producer)
+  }
+
+  def commit(): WriterCommitMessage = {
+    // Send is asynchronous, but we can't commit until all rows are actually in Kafka.
+    // This requires flushing and then checking that no callbacks produced errors.
+    // We also check for errors before to fail as soon as possible - the check is cheap.
+    checkForErrors()
+    producer.flush()
+    checkForErrors()
+    KafkaWriterCommitMessage
+  }
+
+  def abort(): Unit = {}
+
+  def close(): Unit = {
+    checkForErrors()
+    if (producer != null) {
+      producer.flush()
+      checkForErrors()
+      CachedKafkaProducer.close(new java.util.HashMap[String, Object](producerParams.asJava))
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
index 3e65949..551641c 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala
@@ -117,10 +117,14 @@ private[kafka010] class KafkaOffsetReader(
    * Resolves the specific offsets based on Kafka seek positions.
    * This method resolves offset value -1 to the latest and -2 to the
    * earliest Kafka seek position.
+   *
+   * @param partitionOffsets the specific offsets to resolve
+   * @param reportDataLoss callback to either report or log data loss depending on setting
    */
   def fetchSpecificOffsets(
-      partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] =
-    runUninterruptibly {
+      partitionOffsets: Map[TopicPartition, Long],
+      reportDataLoss: String => Unit): KafkaSourceOffset = {
+    val fetched = runUninterruptibly {
       withRetriesWithoutInterrupt {
         // Poll to get the latest assigned partitions
         consumer.poll(0)
@@ -145,6 +149,19 @@ private[kafka010] class KafkaOffsetReader(
       }
     }
 
+    partitionOffsets.foreach {
+      case (tp, off) if off != KafkaOffsetRangeLimit.LATEST &&
+        off != KafkaOffsetRangeLimit.EARLIEST =>
+        if (fetched(tp) != off) {
+          reportDataLoss(
+            s"startingOffsets for $tp was $off but consumer reset to ${fetched(tp)}")
+        }
+      case _ =>
+        // no real way to check that beginning or end is reasonable
+    }
+    KafkaSourceOffset(fetched)
+  }
+
   /**
    * Fetch the earliest offsets for the topic partitions that are indicated
    * in the [[ConsumerStrategy]].

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
index e9cff04..27da760 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSource.scala
@@ -130,7 +130,7 @@ private[kafka010] class KafkaSource(
       val offsets = startingOffsets match {
         case EarliestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchEarliestOffsets())
         case LatestOffsetRangeLimit => KafkaSourceOffset(kafkaReader.fetchLatestOffsets())
-        case SpecificOffsetRangeLimit(p) => fetchAndVerify(p)
+        case SpecificOffsetRangeLimit(p) => kafkaReader.fetchSpecificOffsets(p, reportDataLoss)
       }
       metadataLog.add(0, offsets)
       logInfo(s"Initial offsets: $offsets")
@@ -138,21 +138,6 @@ private[kafka010] class KafkaSource(
     }.partitionToOffsets
   }
 
-  private def fetchAndVerify(specificOffsets: Map[TopicPartition, Long]) = {
-    val result = kafkaReader.fetchSpecificOffsets(specificOffsets)
-    specificOffsets.foreach {
-      case (tp, off) if off != KafkaOffsetRangeLimit.LATEST &&
-          off != KafkaOffsetRangeLimit.EARLIEST =>
-        if (result(tp) != off) {
-          reportDataLoss(
-            s"startingOffsets for $tp was $off but consumer reset to ${result(tp)}")
-        }
-      case _ =>
-      // no real way to check that beginning or end is reasonable
-    }
-    KafkaSourceOffset(result)
-  }
-
   private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None
 
   override def schema: StructType = KafkaOffsetReader.kafkaSchema

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala
index b5da415..c82154c 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala
@@ -20,17 +20,22 @@ package org.apache.spark.sql.kafka010
 import org.apache.kafka.common.TopicPartition
 
 import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset}
+import org.apache.spark.sql.sources.v2.streaming.reader.{Offset => OffsetV2, PartitionOffset}
 
 /**
  * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and
  * their offsets.
  */
 private[kafka010]
-case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset {
+case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 {
 
   override val json = JsonUtils.partitionOffsets(partitionToOffsets)
 }
 
+private[kafka010]
+case class KafkaSourcePartitionOffset(topicPartition: TopicPartition, partitionOffset: Long)
+  extends PartitionOffset
+
 /** Companion object of the [[KafkaSourceOffset]] */
 private[kafka010] object KafkaSourceOffset {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
index 3cb4d8c..3914370 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
-import java.util.{Locale, UUID}
+import java.util.{Locale, Optional, UUID}
 
 import scala.collection.JavaConverters._
 
@@ -27,9 +27,12 @@ import org.apache.kafka.clients.producer.ProducerConfig
 import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer}
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
-import org.apache.spark.sql.execution.streaming.{Sink, Source}
+import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SparkSession, SQLContext}
+import org.apache.spark.sql.execution.streaming.{Offset, Sink, Source}
 import org.apache.spark.sql.sources._
+import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options}
+import org.apache.spark.sql.sources.v2.streaming.{ContinuousReadSupport, ContinuousWriteSupport}
+import org.apache.spark.sql.sources.v2.streaming.writer.ContinuousWriter
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 
@@ -43,6 +46,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     with StreamSinkProvider
     with RelationProvider
     with CreatableRelationProvider
+    with ContinuousWriteSupport
+    with ContinuousReadSupport
     with Logging {
   import KafkaSourceProvider._
 
@@ -101,6 +106,43 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
       failOnDataLoss(caseInsensitiveParams))
   }
 
+  override def createContinuousReader(
+      schema: Optional[StructType],
+      metadataPath: String,
+      options: DataSourceV2Options): KafkaContinuousReader = {
+    val parameters = options.asMap().asScala.toMap
+    validateStreamOptions(parameters)
+    // Each running query should use its own group id. Otherwise, the query may be only assigned
+    // partial data since Kafka will assign partitions to multiple consumers having the same group
+    // id. Hence, we should generate a unique id for each query.
+    val uniqueGroupId = s"spark-kafka-source-${UUID.randomUUID}-${metadataPath.hashCode}"
+
+    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+    val specifiedKafkaParams =
+      parameters
+        .keySet
+        .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
+        .map { k => k.drop(6).toString -> parameters(k) }
+        .toMap
+
+    val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
+      STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
+
+    val kafkaOffsetReader = new KafkaOffsetReader(
+      strategy(caseInsensitiveParams),
+      kafkaParamsForDriver(specifiedKafkaParams),
+      parameters,
+      driverGroupIdPrefix = s"$uniqueGroupId-driver")
+
+    new KafkaContinuousReader(
+      kafkaOffsetReader,
+      kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
+      parameters,
+      metadataPath,
+      startingStreamOffsets,
+      failOnDataLoss(caseInsensitiveParams))
+  }
+
   /**
    * Returns a new base relation with the given parameters.
    *
@@ -181,26 +223,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
     }
   }
 
-  private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = {
-    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
-    if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
-      throw new IllegalArgumentException(
-        s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys "
-          + "are serialized with ByteArraySerializer.")
-    }
+  override def createContinuousWriter(
+      queryId: String,
+      schema: StructType,
+      mode: OutputMode,
+      options: DataSourceV2Options): Optional[ContinuousWriter] = {
+    import scala.collection.JavaConverters._
 
-    if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}"))
-    {
-      throw new IllegalArgumentException(
-        s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as "
-          + "value are serialized with ByteArraySerializer.")
-    }
-    parameters
-      .keySet
-      .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
-      .map { k => k.drop(6).toString -> parameters(k) }
-      .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName,
-        ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName)
+    val spark = SparkSession.getActiveSession.get
+    val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim)
+    // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable.
+    val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap)
+
+    KafkaWriter.validateQuery(
+      schema.toAttributes, new java.util.HashMap[String, Object](producerParams.asJava), topic)
+
+    Optional.of(new KafkaContinuousWriter(topic, producerParams, schema))
   }
 
   private def strategy(caseInsensitiveParams: Map[String, String]) =
@@ -450,4 +488,27 @@ private[kafka010] object KafkaSourceProvider extends Logging {
 
     def build(): ju.Map[String, Object] = map
   }
+
+  private[kafka010] def kafkaParamsForProducer(
+      parameters: Map[String, String]): Map[String, String] = {
+    val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
+    if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) {
+      throw new IllegalArgumentException(
+        s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys "
+          + "are serialized with ByteArraySerializer.")
+    }
+
+    if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}"))
+    {
+      throw new IllegalArgumentException(
+        s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as "
+          + "value are serialized with ByteArraySerializer.")
+    }
+    parameters
+      .keySet
+      .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
+      .map { k => k.drop(6).toString -> parameters(k) }
+      .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName,
+      ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
index 6fd333e..baa60fe 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
@@ -33,10 +33,8 @@ import org.apache.spark.sql.types.{BinaryType, StringType}
 private[kafka010] class KafkaWriteTask(
     producerConfiguration: ju.Map[String, Object],
     inputSchema: Seq[Attribute],
-    topic: Option[String]) {
+    topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) {
   // used to synchronize with Kafka callbacks
-  @volatile private var failedWrite: Exception = null
-  private val projection = createProjection
   private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _
 
   /**
@@ -46,23 +44,7 @@ private[kafka010] class KafkaWriteTask(
     producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
     while (iterator.hasNext && failedWrite == null) {
       val currentRow = iterator.next()
-      val projectedRow = projection(currentRow)
-      val topic = projectedRow.getUTF8String(0)
-      val key = projectedRow.getBinary(1)
-      val value = projectedRow.getBinary(2)
-      if (topic == null) {
-        throw new NullPointerException(s"null topic present in the data. Use the " +
-        s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
-      }
-      val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
-      val callback = new Callback() {
-        override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
-          if (failedWrite == null && e != null) {
-            failedWrite = e
-          }
-        }
-      }
-      producer.send(record, callback)
+      sendRow(currentRow, producer)
     }
   }
 
@@ -74,8 +56,49 @@ private[kafka010] class KafkaWriteTask(
       producer = null
     }
   }
+}
+
+private[kafka010] abstract class KafkaRowWriter(
+    inputSchema: Seq[Attribute], topic: Option[String]) {
+
+  // used to synchronize with Kafka callbacks
+  @volatile protected var failedWrite: Exception = _
+  protected val projection = createProjection
+
+  private val callback = new Callback() {
+    override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
+      if (failedWrite == null && e != null) {
+        failedWrite = e
+      }
+    }
+  }
 
-  private def createProjection: UnsafeProjection = {
+  /**
+   * Send the specified row to the producer, with a callback that will save any exception
+   * to failedWrite. Note that send is asynchronous; subclasses must flush() their producer before
+   * assuming the row is in Kafka.
+   */
+  protected def sendRow(
+      row: InternalRow, producer: KafkaProducer[Array[Byte], Array[Byte]]): Unit = {
+    val projectedRow = projection(row)
+    val topic = projectedRow.getUTF8String(0)
+    val key = projectedRow.getBinary(1)
+    val value = projectedRow.getBinary(2)
+    if (topic == null) {
+      throw new NullPointerException(s"null topic present in the data. Use the " +
+        s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
+    }
+    val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
+    producer.send(record, callback)
+  }
+
+  protected def checkForErrors(): Unit = {
+    if (failedWrite != null) {
+      throw failedWrite
+    }
+  }
+
+  private def createProjection = {
     val topicExpression = topic.map(Literal(_)).orElse {
       inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME)
     }.getOrElse {
@@ -112,11 +135,5 @@ private[kafka010] class KafkaWriteTask(
       Seq(topicExpression, Cast(keyExpression, BinaryType),
         Cast(valueExpression, BinaryType)), inputSchema)
   }
-
-  private def checkForErrors(): Unit = {
-    if (failedWrite != null) {
-      throw failedWrite
-    }
-  }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
index 5e9ae35..15cd448 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
@@ -43,10 +43,9 @@ private[kafka010] object KafkaWriter extends Logging {
   override def toString: String = "KafkaWriter"
 
   def validateQuery(
-      queryExecution: QueryExecution,
+      schema: Seq[Attribute],
       kafkaParameters: ju.Map[String, Object],
       topic: Option[String] = None): Unit = {
-    val schema = queryExecution.analyzed.output
     schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse(
       if (topic.isEmpty) {
         throw new AnalysisException(s"topic option required when no " +
@@ -84,7 +83,7 @@ private[kafka010] object KafkaWriter extends Logging {
       kafkaParameters: ju.Map[String, Object],
       topic: Option[String] = None): Unit = {
     val schema = queryExecution.analyzed.output
-    validateQuery(queryExecution, kafkaParameters, topic)
+    validateQuery(schema, kafkaParameters, topic)
     queryExecution.toRdd.foreachPartition { iter =>
       val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic)
       Utils.tryWithSafeFinally(block = writeTask.execute(iter))(

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
new file mode 100644
index 0000000..dfc97b1
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
@@ -0,0 +1,474 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.util.Locale
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.apache.kafka.clients.producer.ProducerConfig
+import org.apache.kafka.common.serialization.ByteArraySerializer
+import org.scalatest.time.SpanSugar._
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming._
+import org.apache.spark.sql.types.{BinaryType, DataType}
+import org.apache.spark.util.Utils
+
+/**
+ * This is a temporary port of KafkaSinkSuite, since we do not yet have a V2 memory stream.
+ * Once we have one, this will be changed to a specialization of KafkaSinkSuite and we won't have
+ * to duplicate all the code.
+ */
+class KafkaContinuousSinkSuite extends KafkaContinuousTest {
+  import testImplicits._
+
+  override val streamingTimeout = 30.seconds
+
+  // KafkaSourceTest.beforeAll builds the suite's KafkaTestUtils from brokerProps and sets it up.
+  // Overriding brokerProps (rather than beforeAll) keeps a single broker per suite: re-creating
+  // testUtils after super.beforeAll() would leak the instance the parent already started.
+  override val brokerProps = Map("auto.create.topics.enable" -> "false")
+
+  override def afterAll(): Unit = {
+    if (testUtils != null) {
+      testUtils.teardown()
+      testUtils = null
+    }
+    super.afterAll()
+  }
+
+  /**
+   * Stops `query` if it was started. `query` is still null when `createKafkaWriter` threw before
+   * returning; calling stop() then would raise an NPE that hides the original failure.
+   */
+  private def stopIfStarted(query: StreamingQuery): Unit = {
+    if (query != null) {
+      query.stop()
+    }
+  }
+
+  test("streaming - write to kafka with topic field") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 1)
+
+    val input = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", inputTopic)
+      .option("startingOffsets", "earliest")
+      .load()
+
+    val topic = newTopic()
+    testUtils.createTopic(topic)
+
+    val writer = createKafkaWriter(
+      input.toDF(),
+      withTopic = None,
+      withOutputMode = Some(OutputMode.Append()))(
+      withSelectExpr = s"'$topic' as topic", "value")
+
+    val reader = createKafkaReader(topic)
+      .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
+      .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
+      .as[(Int, Int)]
+      .map(_._2)
+
+    try {
+      testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+      failAfter(streamingTimeout) {
+        writer.processAllAvailable()
+      }
+      checkDatasetUnorderly(reader, 1, 2, 3, 4, 5)
+      testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10"))
+      failAfter(streamingTimeout) {
+        writer.processAllAvailable()
+      }
+      checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+    } finally {
+      writer.stop()
+    }
+  }
+
+  test("streaming - write w/o topic field, with topic option") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 1)
+
+    val input = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", inputTopic)
+      .option("startingOffsets", "earliest")
+      .load()
+
+    val topic = newTopic()
+    testUtils.createTopic(topic)
+
+    val writer = createKafkaWriter(
+      input.toDF(),
+      withTopic = Some(topic),
+      withOutputMode = Some(OutputMode.Append()))()
+
+    val reader = createKafkaReader(topic)
+      .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
+      .selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
+      .as[(Int, Int)]
+      .map(_._2)
+
+    try {
+      testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+      failAfter(streamingTimeout) {
+        writer.processAllAvailable()
+      }
+      checkDatasetUnorderly(reader, 1, 2, 3, 4, 5)
+      testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10"))
+      failAfter(streamingTimeout) {
+        writer.processAllAvailable()
+      }
+      checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+    } finally {
+      writer.stop()
+    }
+  }
+
+  test("streaming - topic field and topic option") {
+    /* The purpose of this test is to ensure that the topic option
+     * overrides the topic field. We begin by writing some data that
+     * includes a topic field and value (e.g., 'foo') along with a topic
+     * option. Then when we read from the topic specified in the option
+     * we should see the data i.e., the data was written to the topic
+     * option, and not to the topic in the data e.g., foo
+     */
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 1)
+
+    val input = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", inputTopic)
+      .option("startingOffsets", "earliest")
+      .load()
+
+    val topic = newTopic()
+    testUtils.createTopic(topic)
+
+    val writer = createKafkaWriter(
+      input.toDF(),
+      withTopic = Some(topic),
+      withOutputMode = Some(OutputMode.Append()))(
+      withSelectExpr = "'foo' as topic", "CAST(value as STRING) value")
+
+    val reader = createKafkaReader(topic)
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .selectExpr("CAST(key AS INT)", "CAST(value AS INT)")
+      .as[(Int, Int)]
+      .map(_._2)
+
+    try {
+      testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+      failAfter(streamingTimeout) {
+        writer.processAllAvailable()
+      }
+      checkDatasetUnorderly(reader, 1, 2, 3, 4, 5)
+      testUtils.sendMessages(inputTopic, Array("6", "7", "8", "9", "10"))
+      failAfter(streamingTimeout) {
+        writer.processAllAvailable()
+      }
+      checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
+    } finally {
+      writer.stop()
+    }
+  }
+
+  test("null topic attribute") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 1)
+
+    val input = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", inputTopic)
+      .option("startingOffsets", "earliest")
+      .load()
+    val topic = newTopic()
+    testUtils.createTopic(topic)
+
+    /* Topic field is present but null, and there is no topic option */
+    var writer: StreamingQuery = null
+    var ex: Exception = null
+    try {
+      ex = intercept[StreamingQueryException] {
+        writer = createKafkaWriter(input.toDF())(
+          withSelectExpr = "CAST(null as STRING) as topic", "value"
+        )
+        testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+        writer.processAllAvailable()
+      }
+    } finally {
+      stopIfStarted(writer)
+    }
+    assert(ex.getCause.getCause.getMessage
+      .toLowerCase(Locale.ROOT)
+      .contains("null topic present in the data."))
+  }
+
+  test("streaming - write data with bad schema") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 1)
+
+    val input = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", inputTopic)
+      .option("startingOffsets", "earliest")
+      .load()
+    val topic = newTopic()
+    testUtils.createTopic(topic)
+
+    /* No topic field or topic option */
+    var writer: StreamingQuery = null
+    var ex: Exception = null
+    try {
+      ex = intercept[StreamingQueryException] {
+        writer = createKafkaWriter(input.toDF())(
+          withSelectExpr = "value as key", "value"
+        )
+        testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+        writer.processAllAvailable()
+      }
+    } finally {
+      stopIfStarted(writer)
+    }
+    assert(ex.getMessage
+      .toLowerCase(Locale.ROOT)
+      .contains("topic option required when no 'topic' attribute is present"))
+
+    try {
+      /* No value field */
+      ex = intercept[StreamingQueryException] {
+        writer = createKafkaWriter(input.toDF())(
+          withSelectExpr = s"'$topic' as topic", "value as key"
+        )
+        testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+        writer.processAllAvailable()
+      }
+    } finally {
+      stopIfStarted(writer)
+    }
+    assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
+      "required attribute 'value' not found"))
+  }
+
+  test("streaming - write data with valid schema but wrong types") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 1)
+
+    val input = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", inputTopic)
+      .option("startingOffsets", "earliest")
+      .load()
+      .selectExpr("CAST(value as STRING) value")
+    val topic = newTopic()
+    testUtils.createTopic(topic)
+
+    var writer: StreamingQuery = null
+    var ex: Exception = null
+    try {
+      /* topic field wrong type */
+      ex = intercept[StreamingQueryException] {
+        writer = createKafkaWriter(input.toDF())(
+          withSelectExpr = s"CAST('1' as INT) as topic", "value"
+        )
+        testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+        writer.processAllAvailable()
+      }
+    } finally {
+      stopIfStarted(writer)
+    }
+    assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string"))
+
+    try {
+      /* value field wrong type */
+      ex = intercept[StreamingQueryException] {
+        writer = createKafkaWriter(input.toDF())(
+          withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value"
+        )
+        testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+        writer.processAllAvailable()
+      }
+    } finally {
+      stopIfStarted(writer)
+    }
+    assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
+      "value attribute type must be a string or binarytype"))
+
+    try {
+      ex = intercept[StreamingQueryException] {
+        /* key field wrong type */
+        writer = createKafkaWriter(input.toDF())(
+          withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value"
+        )
+        testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+        writer.processAllAvailable()
+      }
+    } finally {
+      stopIfStarted(writer)
+    }
+    assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
+      "key attribute type must be a string or binarytype"))
+  }
+
+  test("streaming - write to non-existing topic") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 1)
+
+    val input = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", inputTopic)
+      .option("startingOffsets", "earliest")
+      .load()
+    val topic = newTopic()
+
+    var writer: StreamingQuery = null
+    var ex: Exception = null
+    try {
+      ex = intercept[StreamingQueryException] {
+        writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))()
+        testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5"))
+        eventually(timeout(streamingTimeout)) {
+          assert(writer.exception.isDefined)
+        }
+        throw writer.exception.get
+      }
+    } finally {
+      stopIfStarted(writer)
+    }
+    assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("job aborted"))
+  }
+
+  test("streaming - exception on config serializer") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 1)
+    testUtils.sendMessages(inputTopic, Array("0"))
+
+    val input = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", inputTopic)
+      .load()
+    var writer: StreamingQuery = null
+    var ex: Exception = null
+    try {
+      ex = intercept[StreamingQueryException] {
+        writer = createKafkaWriter(
+          input.toDF(),
+          withOptions = Map("kafka.key.serializer" -> "foo"))()
+        writer.processAllAvailable()
+      }
+      assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
+        "kafka option 'key.serializer' is not supported"))
+    } finally {
+      stopIfStarted(writer)
+    }
+
+    try {
+      ex = intercept[StreamingQueryException] {
+        writer = createKafkaWriter(
+          input.toDF(),
+          withOptions = Map("kafka.value.serializer" -> "foo"))()
+        writer.processAllAvailable()
+      }
+      assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
+        "kafka option 'value.serializer' is not supported"))
+    } finally {
+      stopIfStarted(writer)
+    }
+  }
+
+  test("generic - write big data with small producer buffer") {
+    /* This test ensures that we understand the semantics of Kafka when
+     * it comes to blocking on a call to send when the send buffer is full.
+     * This test will configure the smallest possible producer buffer and
+     * indicate that we should block when it is full. Thus, no exception should
+     * be thrown in the case of a full buffer.
+     */
+    val topic = newTopic()
+    testUtils.createTopic(topic, 1)
+    val options = new java.util.HashMap[String, String]
+    options.put("bootstrap.servers", testUtils.brokerAddress)
+    options.put("buffer.memory", "16384") // min buffer size
+    options.put("block.on.buffer.full", "true")
+    options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName)
+    options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName)
+    val inputSchema = Seq(AttributeReference("value", BinaryType)())
+    val data = new Array[Byte](15000) // large value
+    val writeTask = new KafkaContinuousDataWriter(Some(topic), options.asScala.toMap, inputSchema)
+    try {
+      val fieldTypes: Array[DataType] = Array(BinaryType)
+      val converter = UnsafeProjection.create(fieldTypes)
+      val row = new SpecificInternalRow(fieldTypes)
+      row.update(0, data)
+      val iter = Seq.fill(1000)(converter.apply(row)).iterator
+      iter.foreach(writeTask.write(_))
+      writeTask.commit()
+    } finally {
+      writeTask.close()
+    }
+  }
+
+  /** Batch-reads every record in `topic`, from the earliest to the latest offsets. */
+  private def createKafkaReader(topic: String): DataFrame = {
+    spark.read
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("startingOffsets", "earliest")
+      .option("endingOffsets", "latest")
+      .option("subscribe", topic)
+      .load()
+  }
+
+  /**
+   * Starts a Kafka sink query over `input` with a continuous trigger, a fresh checkpoint
+   * directory and the optional topic / output mode / extra options applied. When
+   * `withSelectExpr` is non-empty, `input` is first projected through those expressions.
+   */
+  private def createKafkaWriter(
+      input: DataFrame,
+      withTopic: Option[String] = None,
+      withOutputMode: Option[OutputMode] = None,
+      withOptions: Map[String, String] = Map[String, String]())
+      (withSelectExpr: String*): StreamingQuery = {
+    val checkpointDir = Utils.createTempDir()
+    val df = if (withSelectExpr.nonEmpty) {
+      input.toDF().selectExpr(withSelectExpr: _*)
+    } else {
+      input.toDF()
+    }
+    val stream: DataStreamWriter[Row] = df.writeStream
+      .format("kafka")
+      .option("checkpointLocation", checkpointDir.getCanonicalPath)
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      // We need to reduce blocking time to efficiently test non-existent partition behavior.
+      .option("kafka.max.block.ms", "1000")
+      .trigger(Trigger.Continuous(1000))
+      .queryName("kafkaStream")
+    withTopic.foreach(stream.option("topic", _))
+    withOutputMode.foreach(stream.outputMode(_))
+    withOptions.foreach { case (key, value) => stream.option(key, value) }
+    stream.start()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
new file mode 100644
index 0000000..b3dade4
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import java.util.Properties
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.scalatest.time.SpanSugar._
+import scala.collection.mutable
+import scala.util.Random
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.streaming.StreamExecution
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
+import org.apache.spark.sql.streaming.{StreamTest, Trigger}
+import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
+
+/** Runs every test inherited from KafkaSourceSuiteBase under continuous execution. */
+class KafkaContinuousSourceSuite
+    extends KafkaSourceSuiteBase
+    with KafkaContinuousTest
+
+class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
+  import testImplicits._
+
+  // Broker setting for this suite: disable topic auto-creation.
+  override val brokerProps = Map("auto.create.topics.enable" -> "false")
+
+  test("subscribing topic by pattern with topic deletions") {
+    val topicPrefix = newTopic()
+    val topic = s"$topicPrefix-seems"
+    val topic2 = s"$topicPrefix-bad"
+    testUtils.createTopic(topic, partitions = 5)
+    testUtils.sendMessages(topic, Array("-1"))
+    require(testUtils.getLatestOffsets(Set(topic)).size === 5)
+
+    // Stream over every topic matching the prefix pattern, emitting each value plus one.
+    val mapped = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("subscribePattern", s"$topicPrefix-.*")
+      .option("failOnDataLoss", "false")
+      .load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+      .map(record => record._2.toInt + 1)
+
+    testStream(mapped)(
+      makeSureGetOffsetCalled,
+      AddKafkaData(Set(topic), 1, 2, 3),
+      CheckAnswer(2, 3, 4),
+      Execute { query =>
+        testUtils.deleteTopic(topic)
+        testUtils.createTopic(topic2, partitions = 5)
+        // Wait until the running reader knows about the new topic. Only the presence of topic2
+        // is asserted here; removal of the deleted topic is not checked.
+        eventually(timeout(streamingTimeout)) {
+          assert(
+            query.lastExecution.logical.collectFirst {
+              case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
+            }.exists(_.knownPartitions.exists(_.topic == topic2)),
+            s"query never reconfigured to new topic $topic2")
+        }
+      },
+      AddKafkaData(Set(topic2), 4, 5, 6),
+      CheckAnswer(2, 3, 4, 5, 6, 7)
+    )
+  }
+}
+
+class KafkaContinuousSourceStressForDontFailOnDataLossSuite
+    extends KafkaSourceStressForDontFailOnDataLossSuite {
+  // Runs the inherited stress test through a continuous query writing to the memory sink.
+  // The explicit continuous trigger is what makes this suite differ from its parent: without
+  // it, DataStreamWriter falls back to its default (micro-batch) trigger.
+  override protected def startStream(ds: Dataset[Int]) = {
+    ds.writeStream
+      .format("memory")
+      .queryName("memory")
+      .trigger(Trigger.Continuous(1000))
+      .start()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
new file mode 100644
index 0000000..e713e66
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.streaming.StreamExecution
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.sql.test.TestSparkSession
+
+/**
+ * Mixin that configures StreamTest for Kafka continuous execution tests. It uses a continuous
+ * trigger and the V2 sink, and runs on a local[10] context so that all partitions can be
+ * scheduled simultaneously.
+ */
+trait KafkaContinuousTest extends KafkaSourceTest {
+  override val defaultTrigger = Trigger.Continuous(1000)
+  override val defaultUseV2Sink = true
+
+  // We need more than the default local[2] to be able to schedule all partitions simultaneously.
+  override protected def createSparkSession = new TestSparkSession(
+    new SparkContext(
+      "local[10]",
+      "continuous-stream-test-sql-context",
+      sparkConf.set("spark.sql.testkey", "true")))
+
+  // In addition to setting the partitions in Kafka, we have to wait until the query has
+  // reconfigured to the new count so the test framework can hook in properly.
+  override protected def setTopicPartitions(
+      topic: String, newCount: Int, query: StreamExecution): Unit = {
+    testUtils.addPartitions(topic, newCount)
+    eventually(timeout(streamingTimeout)) {
+      // lastExecution may still be null before the first execution is planned; treat that as
+      // "not reconfigured yet" so a timeout reports the clue below rather than an NPE.
+      val reader = Option(query.lastExecution).flatMap { execution =>
+        execution.logical.collectFirst {
+          case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
+        }
+      }
+      assert(
+        reader.exists(_.knownPartitions.size == newCount),
+        s"query never reconfigured to $newCount partitions")
+    }
+  }
+
+  test("ensure continuous stream is being used") {
+    val query = spark.readStream
+      .format("rate")
+      .option("numPartitions", "1")
+      .option("rowsPerSecond", "1")
+      .load()
+
+    testStream(query)(
+      Execute(q => assert(q.isInstanceOf[ContinuousExecution]))
+    )
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
----------------------------------------------------------------------
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
index 2034b9b..d66908f8 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
@@ -34,11 +34,14 @@ import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkContext
-import org.apache.spark.sql.ForeachWriter
+import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2Exec}
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
+import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryWriter
 import org.apache.spark.sql.functions.{count, window}
 import org.apache.spark.sql.kafka010.KafkaSourceProvider._
-import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest}
+import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest, Trigger}
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
 import org.apache.spark.util.Utils
@@ -49,9 +52,11 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
 
   override val streamingTimeout = 30.seconds
 
+  // Extra broker properties passed to KafkaTestUtils in beforeAll; empty unless a suite overrides.
+  protected val brokerProps: Map[String, Object] = Map.empty
+
+  // Creates the suite's shared KafkaTestUtils from brokerProps and calls setup() on it.
   override def beforeAll(): Unit = {
     super.beforeAll()
-    testUtils = new KafkaTestUtils
+    testUtils = new KafkaTestUtils(brokerProps)
     testUtils.setup()
   }
 
@@ -59,18 +64,25 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
     if (testUtils != null) {
       testUtils.teardown()
       testUtils = null
-      super.afterAll()
     }
+    super.afterAll()
   }
 
+  // Stream action that blocks until the source has resolved its starting offsets.
   protected def makeSureGetOffsetCalled = AssertOnQuery { q =>
     // Because KafkaSource's initialPartitionOffsets is set lazily, we need to make sure
-    // its "getOffset" is called before pushing any data. Otherwise, because of the race contion,
+    // its "getOffset" is called before pushing any data. Otherwise, because of the race condition,
     // we don't know which data should be fetched when `startingOffsets` is latest.
-    q.processAllAvailable()
+    // Continuous queries wait on epoch 0; micro-batch queries drain all available data.
+    // NOTE(review): any other StreamExecution subclass would fail here with a MatchError.
+    q match {
+      case c: ContinuousExecution => c.awaitEpoch(0)
+      case m: MicroBatchExecution => m.processAllAvailable()
+    }
     true
   }
 
+  /**
+   * Grows `topic` to `newCount` partitions. `query` is unused here; subclasses may override to
+   * additionally wait for the running query to observe the change.
+   */
+  protected def setTopicPartitions(
+      topic: String, newCount: Int, query: StreamExecution): Unit =
+    testUtils.addPartitions(topic, newCount)
+
   /**
    * Add data to Kafka.
    *
@@ -82,7 +94,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
       message: String = "",
       topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData {
 
-    override def addData(query: Option[StreamExecution]): (Source, Offset) = {
+    override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
       if (query.get.isActive) {
         // Make sure no Spark job is running when deleting a topic
         query.get.processAllAvailable()
@@ -97,16 +109,18 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
         topicAction(existingTopicPartitions._1, Some(existingTopicPartitions._2))
       }
 
-      // Read all topics again in case some topics are delete.
-      val allTopics = testUtils.getAllTopicsAndPartitionSize().toMap.keys
       require(
         query.nonEmpty,
         "Cannot add data when there is no query for finding the active kafka source")
 
       val sources = query.get.logicalPlan.collect {
-        case StreamingExecutionRelation(source, _) if source.isInstanceOf[KafkaSource] =>
-          source.asInstanceOf[KafkaSource]
-      }
+        case StreamingExecutionRelation(source: KafkaSource, _) => source
+      } ++ (query.get.lastExecution match {
+        case null => Seq()
+        case e => e.logical.collect {
+          case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
+        }
+      })
       if (sources.isEmpty) {
         throw new Exception(
           "Could not find Kafka source in the StreamExecution logical plan to add data to")
@@ -137,14 +151,158 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
     override def toString: String =
       s"AddKafkaData(topics = $topics, data = $data, message = $message)"
   }
-}
 
+  private val topicId = new AtomicInteger(0)
+  protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
+}
 
-class KafkaSourceSuite extends KafkaSourceTest {
+class KafkaMicroBatchSourceSuite extends KafkaSourceSuiteBase {
 
   import testImplicits._
 
-  private val topicId = new AtomicInteger(0)
+  test("(de)serialization of initial offsets") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 5)
+
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("subscribe", topic)
+
+    testStream(reader.load)(
+      makeSureGetOffsetCalled,
+      StopStream,
+      StartStream(),
+      StopStream)
+  }
+
+  test("maxOffsetsPerTrigger") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 3)
+    testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0))
+    testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1))
+    testUtils.sendMessages(topic, Array("1"), Some(2))
+
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("maxOffsetsPerTrigger", 10)
+      .option("subscribe", topic)
+      .option("startingOffsets", "earliest")
+    val kafka = reader.load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+    val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt)
+
+    val clock = new StreamManualClock
+
+    val waitUntilBatchProcessed = AssertOnQuery { q =>
+      eventually(Timeout(streamingTimeout)) {
+        if (!q.exception.isDefined) {
+          assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
+        }
+      }
+      if (q.exception.isDefined) {
+        throw q.exception.get
+      }
+      true
+    }
+
+    testStream(mapped)(
+      StartStream(ProcessingTime(100), clock),
+      waitUntilBatchProcessed,
+      // 1 from smallest, 1 from middle, 8 from biggest
+      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107),
+      AdvanceManualClock(100),
+      waitUntilBatchProcessed,
+      // smallest now empty, 1 more from middle, 9 more from biggest
+      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
+        11, 108, 109, 110, 111, 112, 113, 114, 115, 116
+      ),
+      StopStream,
+      StartStream(ProcessingTime(100), clock),
+      waitUntilBatchProcessed,
+      // smallest now empty, 1 more from middle, 9 more from biggest
+      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
+        11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
+        12, 117, 118, 119, 120, 121, 122, 123, 124, 125
+      ),
+      AdvanceManualClock(100),
+      waitUntilBatchProcessed,
+      // smallest now empty, 1 more from middle, 9 more from biggest
+      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
+        11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
+        12, 117, 118, 119, 120, 121, 122, 123, 124, 125,
+        13, 126, 127, 128, 129, 130, 131, 132, 133, 134
+      )
+    )
+  }
+
+  test("input row metrics") {
+    val topic = newTopic()
+    testUtils.createTopic(topic, partitions = 5)
+    testUtils.sendMessages(topic, Array("-1"))
+    require(testUtils.getLatestOffsets(Set(topic)).size === 5)
+
+    val kafka = spark
+      .readStream
+      .format("kafka")
+      .option("subscribe", topic)
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+
+    val mapped = kafka.map(kv => kv._2.toInt + 1)
+    testStream(mapped)(
+      StartStream(trigger = ProcessingTime(1)),
+      makeSureGetOffsetCalled,
+      AddKafkaData(Set(topic), 1, 2, 3),
+      CheckAnswer(2, 3, 4),
+      AssertOnQuery { query =>
+        val recordsRead = query.recentProgress.map(_.numInputRows).sum
+        recordsRead == 3
+      }
+    )
+  }
+
+  test("subscribing topic by pattern with topic deletions") {
+    val topicPrefix = newTopic()
+    val topic = topicPrefix + "-seems"
+    val topic2 = topicPrefix + "-bad"
+    testUtils.createTopic(topic, partitions = 5)
+    testUtils.sendMessages(topic, Array("-1"))
+    require(testUtils.getLatestOffsets(Set(topic)).size === 5)
+
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("subscribePattern", s"$topicPrefix-.*")
+      .option("failOnDataLoss", "false")
+
+    val kafka = reader.load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+    val mapped = kafka.map(kv => kv._2.toInt + 1)
+
+    testStream(mapped)(
+      makeSureGetOffsetCalled,
+      AddKafkaData(Set(topic), 1, 2, 3),
+      CheckAnswer(2, 3, 4),
+      Assert {
+        testUtils.deleteTopic(topic)
+        testUtils.createTopic(topic2, partitions = 5)
+        true
+      },
+      AddKafkaData(Set(topic2), 4, 5, 6),
+      CheckAnswer(2, 3, 4, 5, 6, 7)
+    )
+  }
 
   testWithUninterruptibleThread(
     "deserialization of initial offset with Spark 2.1.0") {
@@ -237,86 +395,51 @@ class KafkaSourceSuite extends KafkaSourceTest {
     }
   }
 
-  test("(de)serialization of initial offsets") {
-    val topic = newTopic()
-    testUtils.createTopic(topic, partitions = 64)
-
-    val reader = spark
-      .readStream
-      .format("kafka")
-      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
-      .option("subscribe", topic)
-
-    testStream(reader.load)(
-      makeSureGetOffsetCalled,
-      StopStream,
-      StartStream(),
-      StopStream)
-  }
-
-  test("maxOffsetsPerTrigger") {
+  test("KafkaSource with watermark") {
+    val now = System.currentTimeMillis()
     val topic = newTopic()
-    testUtils.createTopic(topic, partitions = 3)
-    testUtils.sendMessages(topic, (100 to 200).map(_.toString).toArray, Some(0))
-    testUtils.sendMessages(topic, (10 to 20).map(_.toString).toArray, Some(1))
-    testUtils.sendMessages(topic, Array("1"), Some(2))
+    testUtils.createTopic(topic, partitions = 1)
+    testUtils.sendMessages(topic, Array(1).map(_.toString))
 
-    val reader = spark
+    val kafka = spark
       .readStream
       .format("kafka")
       .option("kafka.bootstrap.servers", testUtils.brokerAddress)
       .option("kafka.metadata.max.age.ms", "1")
-      .option("maxOffsetsPerTrigger", 10)
+      .option("startingOffsets", s"earliest")
       .option("subscribe", topic)
-      .option("startingOffsets", "earliest")
-    val kafka = reader.load()
-      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
-      .as[(String, String)]
-    val mapped: org.apache.spark.sql.Dataset[_] = kafka.map(kv => kv._2.toInt)
-
-    val clock = new StreamManualClock
+      .load()
 
-    val waitUntilBatchProcessed = AssertOnQuery { q =>
-      eventually(Timeout(streamingTimeout)) {
-        if (!q.exception.isDefined) {
-          assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
-        }
-      }
-      if (q.exception.isDefined) {
-        throw q.exception.get
-      }
-      true
-    }
+    val windowedAggregation = kafka
+      .withWatermark("timestamp", "10 seconds")
+      .groupBy(window($"timestamp", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"window".getField("start") as 'window, $"count")
 
-    testStream(mapped)(
-      StartStream(ProcessingTime(100), clock),
-      waitUntilBatchProcessed,
-      // 1 from smallest, 1 from middle, 8 from biggest
-      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107),
-      AdvanceManualClock(100),
-      waitUntilBatchProcessed,
-      // smallest now empty, 1 more from middle, 9 more from biggest
-      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
-        11, 108, 109, 110, 111, 112, 113, 114, 115, 116
-      ),
-      StopStream,
-      StartStream(ProcessingTime(100), clock),
-      waitUntilBatchProcessed,
-      // smallest now empty, 1 more from middle, 9 more from biggest
-      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
-        11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
-        12, 117, 118, 119, 120, 121, 122, 123, 124, 125
-      ),
-      AdvanceManualClock(100),
-      waitUntilBatchProcessed,
-      // smallest now empty, 1 more from middle, 9 more from biggest
-      CheckAnswer(1, 10, 100, 101, 102, 103, 104, 105, 106, 107,
-        11, 108, 109, 110, 111, 112, 113, 114, 115, 116,
-        12, 117, 118, 119, 120, 121, 122, 123, 124, 125,
-        13, 126, 127, 128, 129, 130, 131, 132, 133, 134
-      )
-    )
+    val query = windowedAggregation
+      .writeStream
+      .format("memory")
+      .outputMode("complete")
+      .queryName("kafkaWatermark")
+      .start()
+    query.processAllAvailable()
+    val rows = spark.table("kafkaWatermark").collect()
+    assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
+    val row = rows(0)
+    // We cannot check the exact window start time as it depends on the time that messages were
+    // inserted by the producer. So here we just use a lower bound to make sure the internal
+    // conversion works.
+    assert(
+      row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000,
+      s"Unexpected results: $row")
+    assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row")
+    query.stop()
   }
+}
+
+class KafkaSourceSuiteBase extends KafkaSourceTest {
+
+  import testImplicits._
 
   test("cannot stop Kafka stream") {
     val topic = newTopic()
@@ -328,7 +451,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
       .format("kafka")
       .option("kafka.bootstrap.servers", testUtils.brokerAddress)
       .option("kafka.metadata.max.age.ms", "1")
-      .option("subscribePattern", s"topic-.*")
+      .option("subscribePattern", s"$topic.*")
 
     val kafka = reader.load()
       .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
@@ -422,65 +545,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
     }
   }
 
-  test("subscribing topic by pattern with topic deletions") {
-    val topicPrefix = newTopic()
-    val topic = topicPrefix + "-seems"
-    val topic2 = topicPrefix + "-bad"
-    testUtils.createTopic(topic, partitions = 5)
-    testUtils.sendMessages(topic, Array("-1"))
-    require(testUtils.getLatestOffsets(Set(topic)).size === 5)
-
-    val reader = spark
-      .readStream
-      .format("kafka")
-      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
-      .option("kafka.metadata.max.age.ms", "1")
-      .option("subscribePattern", s"$topicPrefix-.*")
-      .option("failOnDataLoss", "false")
-
-    val kafka = reader.load()
-      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
-      .as[(String, String)]
-    val mapped = kafka.map(kv => kv._2.toInt + 1)
-
-    testStream(mapped)(
-      makeSureGetOffsetCalled,
-      AddKafkaData(Set(topic), 1, 2, 3),
-      CheckAnswer(2, 3, 4),
-      Assert {
-        testUtils.deleteTopic(topic)
-        testUtils.createTopic(topic2, partitions = 5)
-        true
-      },
-      AddKafkaData(Set(topic2), 4, 5, 6),
-      CheckAnswer(2, 3, 4, 5, 6, 7)
-    )
-  }
-
-  test("starting offset is latest by default") {
-    val topic = newTopic()
-    testUtils.createTopic(topic, partitions = 5)
-    testUtils.sendMessages(topic, Array("0"))
-    require(testUtils.getLatestOffsets(Set(topic)).size === 5)
-
-    val reader = spark
-      .readStream
-      .format("kafka")
-      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
-      .option("subscribe", topic)
-
-    val kafka = reader.load()
-      .selectExpr("CAST(value AS STRING)")
-      .as[String]
-    val mapped = kafka.map(_.toInt)
-
-    testStream(mapped)(
-      makeSureGetOffsetCalled,
-      AddKafkaData(Set(topic), 1, 2, 3),
-      CheckAnswer(1, 2, 3)  // should not have 0
-    )
-  }
-
   test("bad source options") {
     def testBadOptions(options: (String, String)*)(expectedMsgs: String*): Unit = {
       val ex = intercept[IllegalArgumentException] {
@@ -540,34 +604,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
     testUnsupportedConfig("kafka.auto.offset.reset", "latest")
   }
 
-  test("input row metrics") {
-    val topic = newTopic()
-    testUtils.createTopic(topic, partitions = 5)
-    testUtils.sendMessages(topic, Array("-1"))
-    require(testUtils.getLatestOffsets(Set(topic)).size === 5)
-
-    val kafka = spark
-      .readStream
-      .format("kafka")
-      .option("subscribe", topic)
-      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
-      .load()
-      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
-      .as[(String, String)]
-
-    val mapped = kafka.map(kv => kv._2.toInt + 1)
-    testStream(mapped)(
-      StartStream(trigger = ProcessingTime(1)),
-      makeSureGetOffsetCalled,
-      AddKafkaData(Set(topic), 1, 2, 3),
-      CheckAnswer(2, 3, 4),
-      AssertOnQuery { query =>
-        val recordsRead = query.recentProgress.map(_.numInputRows).sum
-        recordsRead == 3
-      }
-    )
-  }
-
   test("delete a topic when a Spark job is running") {
     KafkaSourceSuite.collectedData.clear()
 
@@ -629,8 +665,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
     }
   }
 
-  private def newTopic(): String = s"topic-${topicId.getAndIncrement()}"
-
   private def assignString(topic: String, partitions: Iterable[Int]): String = {
     JsonUtils.partitions(partitions.map(p => new TopicPartition(topic, p)))
   }
@@ -676,6 +710,10 @@ class KafkaSourceSuite extends KafkaSourceTest {
 
     testStream(mapped)(
       makeSureGetOffsetCalled,
+      Execute { q =>
+        // wait to reach the last offset in every partition
+        q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L)))
+      },
       CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22),
       StopStream,
       StartStream(),
@@ -706,6 +744,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
       .format("memory")
       .outputMode("append")
       .queryName("kafkaColumnTypes")
+      .trigger(defaultTrigger)
       .start()
     query.processAllAvailable()
     val rows = spark.table("kafkaColumnTypes").collect()
@@ -723,47 +762,6 @@ class KafkaSourceSuite extends KafkaSourceTest {
     query.stop()
   }
 
-  test("KafkaSource with watermark") {
-    val now = System.currentTimeMillis()
-    val topic = newTopic()
-    testUtils.createTopic(newTopic(), partitions = 1)
-    testUtils.sendMessages(topic, Array(1).map(_.toString))
-
-    val kafka = spark
-      .readStream
-      .format("kafka")
-      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
-      .option("kafka.metadata.max.age.ms", "1")
-      .option("startingOffsets", s"earliest")
-      .option("subscribe", topic)
-      .load()
-
-    val windowedAggregation = kafka
-      .withWatermark("timestamp", "10 seconds")
-      .groupBy(window($"timestamp", "5 seconds") as 'window)
-      .agg(count("*") as 'count)
-      .select($"window".getField("start") as 'window, $"count")
-
-    val query = windowedAggregation
-      .writeStream
-      .format("memory")
-      .outputMode("complete")
-      .queryName("kafkaWatermark")
-      .start()
-    query.processAllAvailable()
-    val rows = spark.table("kafkaWatermark").collect()
-    assert(rows.length === 1, s"Unexpected results: ${rows.toList}")
-    val row = rows(0)
-    // We cannot check the exact window start time as it depands on the time that messages were
-    // inserted by the producer. So here we just use a low bound to make sure the internal
-    // conversion works.
-    assert(
-      row.getAs[java.sql.Timestamp]("window").getTime >= now - 5 * 1000,
-      s"Unexpected results: $row")
-    assert(row.getAs[Int]("count") === 1, s"Unexpected results: $row")
-    query.stop()
-  }
-
   private def testFromLatestOffsets(
       topic: String,
       addPartitions: Boolean,
@@ -800,9 +798,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
       AddKafkaData(Set(topic), 7, 8),
       CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9),
       AssertOnQuery("Add partitions") { query: StreamExecution =>
-        if (addPartitions) {
-          testUtils.addPartitions(topic, 10)
-        }
+        if (addPartitions) setTopicPartitions(topic, 10, query)
         true
       },
       AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16),
@@ -843,9 +839,7 @@ class KafkaSourceSuite extends KafkaSourceTest {
       StartStream(),
       CheckAnswer(2, 3, 4, 5, 6, 7, 8, 9),
       AssertOnQuery("Add partitions") { query: StreamExecution =>
-        if (addPartitions) {
-          testUtils.addPartitions(topic, 10)
-        }
+        if (addPartitions) setTopicPartitions(topic, 10, query)
         true
       },
       AddKafkaData(Set(topic), 9, 10, 11, 12, 13, 14, 15, 16),
@@ -977,20 +971,8 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
     }
   }
 
-  test("stress test for failOnDataLoss=false") {
-    val reader = spark
-      .readStream
-      .format("kafka")
-      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
-      .option("kafka.metadata.max.age.ms", "1")
-      .option("subscribePattern", "failOnDataLoss.*")
-      .option("startingOffsets", "earliest")
-      .option("failOnDataLoss", "false")
-      .option("fetchOffset.retryIntervalMs", "3000")
-    val kafka = reader.load()
-      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
-      .as[(String, String)]
-    val query = kafka.map(kv => kv._2.toInt).writeStream.foreach(new ForeachWriter[Int] {
+  protected def startStream(ds: Dataset[Int]) = {
+    ds.writeStream.foreach(new ForeachWriter[Int] {
 
       override def open(partitionId: Long, version: Long): Boolean = {
         true
@@ -1004,6 +986,22 @@ class KafkaSourceStressForDontFailOnDataLossSuite extends StreamTest with Shared
       override def close(errorOrNull: Throwable): Unit = {
       }
     }).start()
+  }
+
+  test("stress test for failOnDataLoss=false") {
+    val reader = spark
+      .readStream
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("kafka.metadata.max.age.ms", "1")
+      .option("subscribePattern", "failOnDataLoss.*")
+      .option("startingOffsets", "earliest")
+      .option("failOnDataLoss", "false")
+      .option("fetchOffset.retryIntervalMs", "3000")
+    val kafka = reader.load()
+      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
+      .as[(String, String)]
+    val query = startStream(kafka.map(kv => kv._2.toInt))
 
     val testTime = 1.minutes
     val startTime = System.currentTimeMillis()

http://git-wip-us.apache.org/repos/asf/spark/blob/6f7aaed8/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index e8d683a..b714a46 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -191,6 +191,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
           ds = ds.asInstanceOf[DataSourceV2],
           conf = sparkSession.sessionState.conf)).asJava)
 
+      // Streaming also uses the data source V2 API. So it may be that the data source implements
+      // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading
+      // the dataframe as a v1 source.
       val reader = (ds, userSpecifiedSchema) match {
         case (ds: ReadSupportWithSchema, Some(schema)) =>
           ds.createReader(schema, options)
@@ -208,23 +211,30 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
           }
           reader
 
-        case _ =>
-          throw new AnalysisException(s"$cls does not support data reading.")
+        case _ => null // fall back to v1
       }
 
-      Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
+      if (reader == null) {
+        loadV1Source(paths: _*)
+      } else {
+        Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
+      }
     } else {
-      // Code path for data source v1.
-      sparkSession.baseRelationToDataFrame(
-        DataSource.apply(
-          sparkSession,
-          paths = paths,
-          userSpecifiedSchema = userSpecifiedSchema,
-          className = source,
-          options = extraOptions.toMap).resolveRelation())
+      loadV1Source(paths: _*)
     }
   }
 
+  private def loadV1Source(paths: String*): DataFrame = {
+    // Data source v1 code path; also the fallback for v2 sources that lack batch read support.
+    sparkSession.baseRelationToDataFrame(
+      DataSource.apply(
+        sparkSession,
+        paths = paths,
+        userSpecifiedSchema = userSpecifiedSchema,
+        className = source,
+        options = extraOptions.toMap).resolveRelation())
+  }
+
   /**
    * Construct a `DataFrame` representing the database table accessible via JDBC URL
    * url named table and connection properties.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org