You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2018/02/07 23:22:59 UTC

spark git commit: [SPARK-23092][SQL] Migrate MemoryStream to DataSourceV2 APIs

Repository: spark
Updated Branches:
  refs/heads/master 9841ae031 -> 30295bf5a


[SPARK-23092][SQL] Migrate MemoryStream to DataSourceV2 APIs

## What changes were proposed in this pull request?

This PR migrates the MemoryStream to DataSourceV2 APIs.

One additional change is in the keys reported in StreamingQueryProgress.durationMs: "getOffset" and "getBatch" have been replaced with "setOffsetRange" and "getEndOffset", as tracking these makes more sense. Unit tests have been changed accordingly.

## How was this patch tested?
Existing unit tests, plus a few updated unit tests.

Author: Tathagata Das <ta...@gmail.com>
Author: Burak Yavuz <br...@gmail.com>

Closes #20445 from tdas/SPARK-23092.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/30295bf5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/30295bf5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/30295bf5

Branch: refs/heads/master
Commit: 30295bf5a6754d0ae43334f7bf00e7a29ed0f1af
Parents: 9841ae0
Author: Tathagata Das <ta...@gmail.com>
Authored: Wed Feb 7 15:22:53 2018 -0800
Committer: Tathagata Das <ta...@gmail.com>
Committed: Wed Feb 7 15:22:53 2018 -0800

----------------------------------------------------------------------
 .../sql/execution/streaming/LongOffset.scala    |   4 +-
 .../streaming/MicroBatchExecution.scala         |  27 ++--
 .../spark/sql/execution/streaming/memory.scala  | 132 +++++++++++--------
 .../streaming/sources/RateStreamSourceV2.scala  |   2 +-
 .../execution/streaming/ForeachSinkSuite.scala  |  55 +++-----
 .../spark/sql/streaming/StreamSuite.scala       |   8 +-
 .../apache/spark/sql/streaming/StreamTest.scala |   2 +-
 .../streaming/StreamingQueryListenerSuite.scala |   5 +-
 .../sql/streaming/StreamingQuerySuite.scala     |  70 ++++++----
 9 files changed, 171 insertions(+), 134 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/30295bf5/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala
index 5f0b195..3ff5b86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala
@@ -17,10 +17,12 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
+
 /**
  * A simple offset for sources that produce a single linear stream of data.
  */
-case class LongOffset(offset: Long) extends Offset {
+case class LongOffset(offset: Long) extends OffsetV2 {
 
   override val json = offset.toString
 

http://git-wip-us.apache.org/repos/asf/spark/blob/30295bf5/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index d9aa857..045d2b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -270,16 +270,17 @@ class MicroBatchExecution(
             }
           case s: MicroBatchReader =>
             updateStatusMessage(s"Getting offsets from $s")
-            reportTimeTaken("getOffset") {
-            // Once v1 streaming source execution is gone, we can refactor this away.
-            // For now, we set the range here to get the source to infer the available end offset,
-            // get that offset, and then set the range again when we later execute.
-            s.setOffsetRange(
-              toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
-              Optional.empty())
-
-              (s, Some(s.getEndOffset))
+            reportTimeTaken("setOffsetRange") {
+              // Once v1 streaming source execution is gone, we can refactor this away.
+              // For now, we set the range here to get the source to infer the available end offset,
+              // get that offset, and then set the range again when we later execute.
+              s.setOffsetRange(
+                toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
+                Optional.empty())
             }
+
+            val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() }
+            (s, Option(currentOffset))
         }.toMap
         availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get)
 
@@ -401,10 +402,14 @@ class MicroBatchExecution(
         case (reader: MicroBatchReader, available)
           if committedOffsets.get(reader).map(_ != available).getOrElse(true) =>
           val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json))
+          val availableV2: OffsetV2 = available match {
+            case v1: SerializedOffset => reader.deserializeOffset(v1.json)
+            case v2: OffsetV2 => v2
+          }
           reader.setOffsetRange(
             toJava(current),
-            Optional.of(available.asInstanceOf[OffsetV2]))
-          logDebug(s"Retrieving data from $reader: $current -> $available")
+            Optional.of(availableV2))
+          logDebug(s"Retrieving data from $reader: $current -> $availableV2")
           Some(reader ->
             new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader))
         case _ => None

http://git-wip-us.apache.org/repos/asf/spark/blob/30295bf5/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 509a69d..352d4ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -17,21 +17,23 @@
 
 package org.apache.spark.sql.execution.streaming
 
+import java.{util => ju}
+import java.util.Optional
 import java.util.concurrent.atomic.AtomicInteger
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
 import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.encoders.encoderFor
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
-import org.apache.spark.sql.execution.SQLExecution
+import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
+import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
@@ -51,9 +53,10 @@ object MemoryStream {
  * available.
  */
 case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
-    extends Source with Logging {
+    extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
   protected val encoder = encoderFor[A]
-  protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession)
+  private val attributes = encoder.schema.toAttributes
+  protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
   protected val output = logicalPlan.output
 
   /**
@@ -61,11 +64,17 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
    * Stored in a ListBuffer to facilitate removing committed batches.
    */
   @GuardedBy("this")
-  protected val batches = new ListBuffer[Dataset[A]]
+  protected val batches = new ListBuffer[Array[UnsafeRow]]
 
   @GuardedBy("this")
   protected var currentOffset: LongOffset = new LongOffset(-1)
 
+  @GuardedBy("this")
+  private var startOffset = new LongOffset(-1)
+
+  @GuardedBy("this")
+  private var endOffset = new LongOffset(-1)
+
   /**
    * Last offset that was discarded, or -1 if no commits have occurred. Note that the value
    * -1 is used in calculations below and isn't just an arbitrary constant.
@@ -73,8 +82,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
   @GuardedBy("this")
   protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)
 
-  def schema: StructType = encoder.schema
-
   def toDS(): Dataset[A] = {
     Dataset(sqlContext.sparkSession, logicalPlan)
   }
@@ -88,72 +95,69 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
   }
 
   def addData(data: TraversableOnce[A]): Offset = {
-    val encoded = data.toVector.map(d => encoder.toRow(d).copy())
-    val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true)
-    val ds = Dataset[A](sqlContext.sparkSession, plan)
-    logDebug(s"Adding ds: $ds")
+    val objects = data.toSeq
+    val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
+    logDebug(s"Adding: $objects")
     this.synchronized {
       currentOffset = currentOffset + 1
-      batches += ds
+      batches += rows
       currentOffset
     }
   }
 
   override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]"
 
-  override def getOffset: Option[Offset] = synchronized {
-    if (currentOffset.offset == -1) {
-      None
-    } else {
-      Some(currentOffset)
+  override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = {
+    synchronized {
+      startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset]
+      endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset]
     }
   }
 
-  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
-    // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
-    val startOrdinal =
-      start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
-    val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
-
-    // Internal buffer only holds the batches after lastCommittedOffset.
-    val newBlocks = synchronized {
-      val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
-      val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
-      assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
-      batches.slice(sliceStart, sliceEnd)
-    }
+  override def readSchema(): StructType = encoder.schema
 
-    if (newBlocks.isEmpty) {
-      return sqlContext.internalCreateDataFrame(
-        sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
-    }
+  override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)
+
+  override def getStartOffset: OffsetV2 = synchronized {
+    if (startOffset.offset == -1) null else startOffset
+  }
 
-    logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal))
+  override def getEndOffset: OffsetV2 = synchronized {
+    if (endOffset.offset == -1) null else endOffset
+  }
 
-    newBlocks
-      .map(_.toDF())
-      .reduceOption(_ union _)
-      .getOrElse {
-        sys.error("No data selected!")
+  override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+    synchronized {
+      // Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
+      val startOrdinal = startOffset.offset.toInt + 1
+      val endOrdinal = endOffset.offset.toInt + 1
+
+      // Internal buffer only holds the batches after lastCommittedOffset.
+      val newBlocks = synchronized {
+        val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
+        val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
+        assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
+        batches.slice(sliceStart, sliceEnd)
       }
+
+      logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal))
+
+      newBlocks.map { block =>
+        new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]]
+      }.asJava
+    }
   }
 
   private def generateDebugString(
-      blocks: TraversableOnce[Dataset[A]],
+      rows: Seq[UnsafeRow],
       startOrdinal: Int,
       endOrdinal: Int): String = {
-    val originalUnsupportedCheck =
-      sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck")
-    try {
-      sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false")
-      s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
-          s"${blocks.flatMap(_.collect()).mkString(", ")}"
-    } finally {
-      sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck)
-    }
+    val fromRow = encoder.resolveAndBind().fromRow _
+    s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
+        s"${rows.map(row => fromRow(row)).mkString(", ")}"
   }
 
-  override def commit(end: Offset): Unit = synchronized {
+  override def commit(end: OffsetV2): Unit = synchronized {
     def check(newOffset: LongOffset): Unit = {
       val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
 
@@ -176,11 +180,33 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
   def reset(): Unit = synchronized {
     batches.clear()
+    startOffset = LongOffset(-1)
+    endOffset = LongOffset(-1)
     currentOffset = new LongOffset(-1)
     lastOffsetCommitted = new LongOffset(-1)
   }
 }
 
+
+class MemoryStreamDataReaderFactory(records: Array[UnsafeRow])
+  extends DataReaderFactory[UnsafeRow] {
+  override def createDataReader(): DataReader[UnsafeRow] = {
+    new DataReader[UnsafeRow] {
+      private var currentIndex = -1
+
+      override def next(): Boolean = {
+        // Return true as long as the new index is in the array.
+        currentIndex += 1
+        currentIndex < records.length
+      }
+
+      override def get(): UnsafeRow = records(currentIndex)
+
+      override def close(): Unit = {}
+    }
+  }
+}
+
 /**
  * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
  * tests and does not provide durability.

http://git-wip-us.apache.org/repos/asf/spark/blob/30295bf5/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
index 1315885..077a255 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamSourceV2.scala
@@ -151,7 +151,7 @@ case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactor
 }
 
 class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] {
-  var currentIndex = -1
+  private var currentIndex = -1
 
   override def next(): Boolean = {
     // Return true as long as the new index is in the seq.

http://git-wip-us.apache.org/repos/asf/spark/blob/30295bf5/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
index 41434e6..b249dd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
@@ -46,49 +46,34 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
         .foreach(new TestForeachWriter())
         .start()
 
-      // -- batch 0 ---------------------------------------
-      input.addData(1, 2, 3, 4)
-      query.processAllAvailable()
+      def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = {
+        import ForeachSinkSuite._
 
-      var expectedEventsForPartition0 = Seq(
-        ForeachSinkSuite.Open(partition = 0, version = 0),
-        ForeachSinkSuite.Process(value = 2),
-        ForeachSinkSuite.Process(value = 3),
-        ForeachSinkSuite.Close(None)
-      )
-      var expectedEventsForPartition1 = Seq(
-        ForeachSinkSuite.Open(partition = 1, version = 0),
-        ForeachSinkSuite.Process(value = 1),
-        ForeachSinkSuite.Process(value = 4),
-        ForeachSinkSuite.Close(None)
-      )
+        val events = ForeachSinkSuite.allEvents()
+        assert(events.size === 2) // one seq of events for each of the 2 partitions
 
-      var allEvents = ForeachSinkSuite.allEvents()
-      assert(allEvents.size === 2)
-      assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))
+        // Verify both seq of events have an Open event as the first event
+        assert(events.map(_.head).toSet === Set(0, 1).map(p => Open(p, expectedVersion)))
+
+        // Verify all the Process event correspond to the expected data
+        val allProcessEvents = events.flatMap(_.filter(_.isInstanceOf[Process[_]]))
+        assert(allProcessEvents.toSet === expectedData.map { data => Process(data) }.toSet)
+
+        // Verify both seq of events have a Close event as the last event
+        assert(events.map(_.last).toSet === Set(Close(None), Close(None)))
+      }
 
+      // -- batch 0 ---------------------------------------
       ForeachSinkSuite.clear()
+      input.addData(1, 2, 3, 4)
+      query.processAllAvailable()
+      verifyOutput(expectedVersion = 0, expectedData = 1 to 4)
 
       // -- batch 1 ---------------------------------------
+      ForeachSinkSuite.clear()
       input.addData(5, 6, 7, 8)
       query.processAllAvailable()
-
-      expectedEventsForPartition0 = Seq(
-        ForeachSinkSuite.Open(partition = 0, version = 1),
-        ForeachSinkSuite.Process(value = 5),
-        ForeachSinkSuite.Process(value = 7),
-        ForeachSinkSuite.Close(None)
-      )
-      expectedEventsForPartition1 = Seq(
-        ForeachSinkSuite.Open(partition = 1, version = 1),
-        ForeachSinkSuite.Process(value = 6),
-        ForeachSinkSuite.Process(value = 8),
-        ForeachSinkSuite.Close(None)
-      )
-
-      allEvents = ForeachSinkSuite.allEvents()
-      assert(allEvents.size === 2)
-      assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))
+      verifyOutput(expectedVersion = 1, expectedData = 5 to 8)
 
       query.stop()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/30295bf5/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index c65e5d3..d1a0483 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -492,16 +492,16 @@ class StreamSuite extends StreamTest {
 
       val explainWithoutExtended = q.explainInternal(false)
       // `extended = false` only displays the physical plan.
-      assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0)
-      assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1)
+      assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0)
+      assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1)
       // Use "StateStoreRestore" to verify that it does output a streaming physical plan
       assert(explainWithoutExtended.contains("StateStoreRestore"))
 
       val explainWithExtended = q.explainInternal(true)
       // `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical
       // plan.
-      assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3)
-      assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1)
+      assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3)
+      assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1)
       // Use "StateStoreRestore" to verify that it does output a streaming physical plan
       assert(explainWithExtended.contains("StateStoreRestore"))
     } finally {

http://git-wip-us.apache.org/repos/asf/spark/blob/30295bf5/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index d643356..37fe595 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -120,7 +120,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
   case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData {
     override def toString: String = s"AddData to $source: ${data.mkString(",")}"
 
-    override def addData(query: Option[StreamExecution]): (Source, Offset) = {
+    override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
       (source, source.addData(data))
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/30295bf5/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
index 79d6519..b96f2bc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.scheduler._
 import org.apache.spark.sql.{Encoder, SparkSession}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
 import org.apache.spark.sql.streaming.StreamingQueryListener._
 import org.apache.spark.sql.streaming.util.StreamManualClock
 import org.apache.spark.util.JsonProtocol
@@ -298,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
       try {
         val input = new MemoryStream[Int](0, sqlContext) {
           @volatile var numTriggers = 0
-          override def getOffset: Option[Offset] = {
+          override def getEndOffset: OffsetV2 = {
             numTriggers += 1
-            super.getOffset
+            super.getEndOffset
           }
         }
         val clock = new StreamManualClock()

http://git-wip-us.apache.org/repos/asf/spark/blob/30295bf5/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 76201c6..3f9aa0d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -17,25 +17,27 @@
 
 package org.apache.spark.sql.streaming
 
+import java.{util => ju}
+import java.util.Optional
 import java.util.concurrent.CountDownLatch
 
 import org.apache.commons.lang3.RandomStringUtils
-import org.mockito.Mockito._
 import org.scalactic.TolerantNumerics
 import org.scalatest.BeforeAndAfter
-import org.scalatest.concurrent.Eventually._
 import org.scalatest.concurrent.PatienceConfiguration.Timeout
 import org.scalatest.mockito.MockitoSugar
 
 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
+import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}
 import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock}
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.ManualClock
 
 class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar {
 
@@ -206,19 +208,29 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
 
     /** Custom MemoryStream that waits for manual clock to reach a time */
     val inputData = new MemoryStream[Int](0, sqlContext) {
-      // getOffset should take 50 ms the first time it is called
-      override def getOffset: Option[Offset] = {
-        val offset = super.getOffset
-        if (offset.nonEmpty) {
-          clock.waitTillTime(1050)
+
+      private def dataAdded: Boolean = currentOffset.offset != -1
+
+      // setOffsetRange should take 50 ms the first time it is called after data is added
+      override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = {
+        synchronized {
+          if (dataAdded) clock.waitTillTime(1050)
+          super.setOffsetRange(start, end)
         }
-        offset
+      }
+
+      // getEndOffset should take 100 ms the first time it is called after data is added
+      override def getEndOffset(): OffsetV2 = synchronized {
+        if (dataAdded) clock.waitTillTime(1150)
+        super.getEndOffset()
       }
 
       // getBatch should take 100 ms the first time it is called
-      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
-        if (start.isEmpty) clock.waitTillTime(1150)
-        super.getBatch(start, end)
+      override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
+        synchronized {
+          clock.waitTillTime(1350)
+          super.createUnsafeRowReaderFactories()
+        }
       }
     }
 
@@ -258,39 +270,44 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
       AssertOnQuery(_.status.message === "Waiting for next trigger"),
       AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
 
-      // Test status and progress while offset is being fetched
+      // Test status and progress when setOffsetRange is being called
       AddData(inputData, 1, 2),
-      AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on getOffset
+      AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on setOffsetRange
       AssertStreamExecThreadIsWaitingForTime(1050),
       AssertOnQuery(_.status.isDataAvailable === false),
       AssertOnQuery(_.status.isTriggerActive === true),
       AssertOnQuery(_.status.message.startsWith("Getting offsets from")),
       AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
 
-      // Test status and progress while batch is being fetched
-      AdvanceManualClock(50), // time = 1050 to unblock getOffset
+      AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange
       AssertClockTime(1050),
-      AssertStreamExecThreadIsWaitingForTime(1150),      // will block on getBatch that needs 1150
+      AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150
+      AssertOnQuery(_.status.isDataAvailable === false),
+      AssertOnQuery(_.status.isTriggerActive === true),
+      AssertOnQuery(_.status.message.startsWith("Getting offsets from")),
+      AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
+
+      AdvanceManualClock(100), // time = 1150 to unblock getEndOffset
+      AssertClockTime(1150),
+      AssertStreamExecThreadIsWaitingForTime(1350), // will block on createReadTasks that needs 1350
       AssertOnQuery(_.status.isDataAvailable === true),
       AssertOnQuery(_.status.isTriggerActive === true),
       AssertOnQuery(_.status.message === "Processing new data"),
       AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
 
-      // Test status and progress while batch is being processed
-      AdvanceManualClock(100), // time = 1150 to unblock getBatch
-      AssertClockTime(1150),
-      AssertStreamExecThreadIsWaitingForTime(1500), // will block in Spark job that needs 1500
+      AdvanceManualClock(200), // time = 1350 to unblock createReadTasks
+      AssertClockTime(1350),
+      AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500
       AssertOnQuery(_.status.isDataAvailable === true),
       AssertOnQuery(_.status.isTriggerActive === true),
       AssertOnQuery(_.status.message === "Processing new data"),
       AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0),
 
       // Test status and progress while batch processing has completed
-      AssertOnQuery { _ => clock.getTimeMillis() === 1150 },
-      AdvanceManualClock(350), // time = 1500 to unblock job
+      AdvanceManualClock(150), // time = 1500 to unblock map task
       AssertClockTime(1500),
       CheckAnswer(2),
-      AssertStreamExecThreadIsWaitingForTime(2000),
+      AssertStreamExecThreadIsWaitingForTime(2000),  // will block until the next trigger
       AssertOnQuery(_.status.isDataAvailable === true),
       AssertOnQuery(_.status.isTriggerActive === false),
       AssertOnQuery(_.status.message === "Waiting for next trigger"),
@@ -307,10 +324,11 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
         assert(progress.numInputRows === 2)
         assert(progress.processedRowsPerSecond === 4.0)
 
-        assert(progress.durationMs.get("getOffset") === 50)
-        assert(progress.durationMs.get("getBatch") === 100)
+        assert(progress.durationMs.get("setOffsetRange") === 50)
+        assert(progress.durationMs.get("getEndOffset") === 100)
         assert(progress.durationMs.get("queryPlanning") === 0)
         assert(progress.durationMs.get("walCommit") === 0)
+        assert(progress.durationMs.get("addBatch") === 350)
         assert(progress.durationMs.get("triggerExecution") === 500)
 
         assert(progress.sources.length === 1)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org