You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by td...@apache.org on 2018/02/23 20:41:03 UTC

spark git commit: [SPARK-23408][SS] Synchronize successive AddData actions in Streaming*JoinSuite

Repository: spark
Updated Branches:
  refs/heads/master 049f243c5 -> 855ce13d0


[SPARK-23408][SS] Synchronize successive AddData actions in Streaming*JoinSuite

**The best way to review this PR is to ignore whitespace/indent changes. Use this link - https://github.com/apache/spark/pull/20650/files?w=1**

## What changes were proposed in this pull request?

The stream-stream join tests add data to multiple sources and expect it all to show up in the next batch. But there's a race condition; the new batch might trigger when only one of the AddData actions has been reached.

A prior attempt to solve this issue, by jose-torres in #20646, tried to synchronize on all memory sources simultaneously whenever consecutive AddData actions were found. However, this carries the risk of deadlock as well as unintended modification of stress tests (see the above PR for a detailed explanation). Instead, this PR attempts the following.

- A new action called `StreamProgressLockedActions` that allows multiple actions to be executed while the streaming query is blocked from making progress. This allows data to be added to multiple sources and made visible simultaneously in the next batch.
- A helper called `MultiAddData`, which wraps two `AddData` actions in a `StreamProgressLockedActions`, is explicitly used in the `Streaming*JoinSuites` to add data to two memory sources simultaneously.

This should avoid unintentional modification of the stress tests (or any other test for that matter) while making sure that the flaky tests are deterministic.

## How was this patch tested?
Modified test cases in `Streaming*JoinSuites` where there are consecutive `AddData` actions.

Author: Tathagata Das <ta...@gmail.com>

Closes #20650 from tdas/SPARK-23408.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/855ce13d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/855ce13d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/855ce13d

Branch: refs/heads/master
Commit: 855ce13d045569b7b16fdc7eee9c981f4ff3a545
Parents: 049f243
Author: Tathagata Das <ta...@gmail.com>
Authored: Fri Feb 23 12:40:58 2018 -0800
Committer: Tathagata Das <ta...@gmail.com>
Committed: Fri Feb 23 12:40:58 2018 -0800

----------------------------------------------------------------------
 .../streaming/MicroBatchExecution.scala         |  10 +
 .../apache/spark/sql/streaming/StreamTest.scala | 472 ++++++++++---------
 .../sql/streaming/StreamingJoinSuite.scala      |  54 +--
 3 files changed, 284 insertions(+), 252 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/855ce13d/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 8465501..6bd0397 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -504,6 +504,16 @@ class MicroBatchExecution(
     }
   }
 
+  /** Execute a function while locking the stream from making any progress */
+  private[sql] def withProgressLocked(f: => Unit): Unit = {
+    awaitProgressLock.lock()
+    try {
+      f
+    } finally {
+      awaitProgressLock.unlock()
+    }
+  }
+
   private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = {
     Optional.ofNullable(scalaOption.orNull)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/855ce13d/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 159dd0e..08f722e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -102,6 +102,19 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
       AddDataMemory(source, data)
   }
 
+  /**
+   * Adds data to multiple memory streams such that all the data will be made visible in the
+   * same batch. This is applicable only to MicroBatchExecution, as this coordination cannot be
+   * performed at the driver in ContinuousExecutions.
+   */
+  object MultiAddData {
+    def apply[A]
+      (source1: MemoryStream[A], data1: A*)(source2: MemoryStream[A], data2: A*): StreamAction = {
+      val actions = Seq(AddDataMemory(source1, data1), AddDataMemory(source2, data2))
+      StreamProgressLockedActions(actions, desc = actions.mkString("[ ", " | ", " ]"))
+    }
+  }
+
   /** A trait that can be extended when testing a source. */
   trait AddData extends StreamAction {
     /**
@@ -217,6 +230,19 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
       s"ExpectFailure[${causeClass.getName}, isFatalError: $isFatalError]"
   }
 
+  /**
+   * Performs multiple actions while locking the stream from progressing.
+   * This is applicable only to MicroBatchExecution, as progress of ContinuousExecution
+   * cannot be controlled from the driver.
+   */
+  case class StreamProgressLockedActions(actions: Seq[StreamAction], desc: String = null)
+    extends StreamAction {
+
+    override def toString(): String = {
+      if (desc != null) desc else super.toString
+    }
+  }
+
   /** Assert that a body is true */
   class Assert(condition: => Boolean, val message: String = "") extends StreamAction {
     def run(): Unit = { Assertions.assert(condition) }
@@ -295,6 +321,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for
     val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode)
     val resetConfValues = mutable.Map[String, Option[String]]()
+    val defaultCheckpointLocation =
+      Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+    var manualClockExpectedTime = -1L
 
     @volatile
     var streamThreadDeathCause: Throwable = null
@@ -425,243 +454,254 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
       }
     }
 
-    var manualClockExpectedTime = -1L
-    val defaultCheckpointLocation =
-      Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
-    try {
-      startedTest.foreach { action =>
-        logInfo(s"Processing test stream action: $action")
-        action match {
-          case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) =>
-            verify(currentStream == null, "stream already running")
-            verify(triggerClock.isInstanceOf[SystemClock]
-              || triggerClock.isInstanceOf[StreamManualClock],
-              "Use either SystemClock or StreamManualClock to start the stream")
-            if (triggerClock.isInstanceOf[StreamManualClock]) {
-              manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis()
+    def executeAction(action: StreamAction): Unit = {
+      logInfo(s"Processing test stream action: $action")
+      action match {
+        case StartStream(trigger, triggerClock, additionalConfs, checkpointLocation) =>
+          verify(currentStream == null, "stream already running")
+          verify(triggerClock.isInstanceOf[SystemClock]
+            || triggerClock.isInstanceOf[StreamManualClock],
+            "Use either SystemClock or StreamManualClock to start the stream")
+          if (triggerClock.isInstanceOf[StreamManualClock]) {
+            manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis()
+          }
+          val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation)
+
+          additionalConfs.foreach(pair => {
+            val value =
+              if (sparkSession.conf.contains(pair._1)) {
+                Some(sparkSession.conf.get(pair._1))
+              } else None
+            resetConfValues(pair._1) = value
+            sparkSession.conf.set(pair._1, pair._2)
+          })
+
+          lastStream = currentStream
+          currentStream =
+            sparkSession
+              .streams
+              .startQuery(
+                None,
+                Some(metadataRoot),
+                stream,
+                Map(),
+                sink,
+                outputMode,
+                trigger = trigger,
+                triggerClock = triggerClock)
+              .asInstanceOf[StreamingQueryWrapper]
+              .streamingQuery
+          // Wait until the initialization finishes, because some tests need to use `logicalPlan`
+          // after starting the query.
+          try {
+            currentStream.awaitInitialization(streamingTimeout.toMillis)
+            currentStream match {
+              case s: ContinuousExecution => eventually("IncrementalExecution was not created") {
+                assert(s.lastExecution != null)
+              }
+              case _ =>
             }
-            val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation)
+          } catch {
+            case _: StreamingQueryException =>
+              // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well.
+          }
 
-            additionalConfs.foreach(pair => {
-              val value =
-                if (sparkSession.conf.contains(pair._1)) {
-                  Some(sparkSession.conf.get(pair._1))
-                } else None
-              resetConfValues(pair._1) = value
-              sparkSession.conf.set(pair._1, pair._2)
-            })
+        case AdvanceManualClock(timeToAdd) =>
+          verify(currentStream != null,
+                 "can not advance manual clock when a stream is not running")
+          verify(currentStream.triggerClock.isInstanceOf[StreamManualClock],
+                 s"can not advance clock of type ${currentStream.triggerClock.getClass}")
+          val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock]
+          assert(manualClockExpectedTime >= 0)
+
+          // Make sure we don't advance ManualClock too early. See SPARK-16002.
+          eventually("StreamManualClock has not yet entered the waiting state") {
+            assert(clock.isStreamWaitingAt(manualClockExpectedTime))
+          }
 
+          clock.advance(timeToAdd)
+          manualClockExpectedTime += timeToAdd
+          verify(clock.getTimeMillis() === manualClockExpectedTime,
+            s"Unexpected clock time after updating: " +
+              s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}")
+
+        case StopStream =>
+          verify(currentStream != null, "can not stop a stream that is not running")
+          try failAfter(streamingTimeout) {
+            currentStream.stop()
+            verify(!currentStream.queryExecutionThread.isAlive,
+              s"microbatch thread not stopped")
+            verify(!currentStream.isActive,
+              "query.isActive() is false even after stopping")
+            verify(currentStream.exception.isEmpty,
+              s"query.exception() is not empty after clean stop: " +
+                currentStream.exception.map(_.toString()).getOrElse(""))
+          } catch {
+            case _: InterruptedException =>
+            case e: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
+              failTest(
+                "Timed out while stopping and waiting for microbatchthread to terminate.", e)
+            case t: Throwable =>
+              failTest("Error while stopping stream", t)
+          } finally {
             lastStream = currentStream
-            currentStream =
-              sparkSession
-                .streams
-                .startQuery(
-                  None,
-                  Some(metadataRoot),
-                  stream,
-                  Map(),
-                  sink,
-                  outputMode,
-                  trigger = trigger,
-                  triggerClock = triggerClock)
-                .asInstanceOf[StreamingQueryWrapper]
-                .streamingQuery
-            // Wait until the initialization finishes, because some tests need to use `logicalPlan`
-            // after starting the query.
-            try {
-              currentStream.awaitInitialization(streamingTimeout.toMillis)
-              currentStream match {
-                case s: ContinuousExecution => eventually("IncrementalExecution was not created") {
-                  assert(s.lastExecution != null)
-                }
-                case _ =>
-              }
-            } catch {
-              case _: StreamingQueryException =>
-                // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well.
-            }
+            currentStream = null
+          }
 
-          case AdvanceManualClock(timeToAdd) =>
-            verify(currentStream != null,
-                   "can not advance manual clock when a stream is not running")
-            verify(currentStream.triggerClock.isInstanceOf[StreamManualClock],
-                   s"can not advance clock of type ${currentStream.triggerClock.getClass}")
-            val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock]
-            assert(manualClockExpectedTime >= 0)
-
-            // Make sure we don't advance ManualClock too early. See SPARK-16002.
-            eventually("StreamManualClock has not yet entered the waiting state") {
-              assert(clock.isStreamWaitingAt(manualClockExpectedTime))
+        case ef: ExpectFailure[_] =>
+          verify(currentStream != null, "can not expect failure when stream is not running")
+          try failAfter(streamingTimeout) {
+            val thrownException = intercept[StreamingQueryException] {
+              currentStream.awaitTermination()
             }
-
-            clock.advance(timeToAdd)
-            manualClockExpectedTime += timeToAdd
-            verify(clock.getTimeMillis() === manualClockExpectedTime,
-              s"Unexpected clock time after updating: " +
-                s"expecting $manualClockExpectedTime, current ${clock.getTimeMillis()}")
-
-          case StopStream =>
-            verify(currentStream != null, "can not stop a stream that is not running")
-            try failAfter(streamingTimeout) {
-              currentStream.stop()
-              verify(!currentStream.queryExecutionThread.isAlive,
-                s"microbatch thread not stopped")
-              verify(!currentStream.isActive,
-                "query.isActive() is false even after stopping")
-              verify(currentStream.exception.isEmpty,
-                s"query.exception() is not empty after clean stop: " +
-                  currentStream.exception.map(_.toString()).getOrElse(""))
-            } catch {
-              case _: InterruptedException =>
-              case e: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
-                failTest(
-                  "Timed out while stopping and waiting for microbatchthread to terminate.", e)
-              case t: Throwable =>
-                failTest("Error while stopping stream", t)
-            } finally {
-              lastStream = currentStream
-              currentStream = null
+            eventually("microbatch thread not stopped after termination with failure") {
+              assert(!currentStream.queryExecutionThread.isAlive)
             }
+            verify(currentStream.exception === Some(thrownException),
+              s"incorrect exception returned by query.exception()")
+
+            val exception = currentStream.exception.get
+            verify(exception.cause.getClass === ef.causeClass,
+              "incorrect cause in exception returned by query.exception()\n" +
+                s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}")
+            if (ef.isFatalError) {
+              // This is a fatal error, `streamThreadDeathCause` should be set to this error in
+              // UncaughtExceptionHandler.
+              verify(streamThreadDeathCause != null &&
+                streamThreadDeathCause.getClass === ef.causeClass,
+                "UncaughtExceptionHandler didn't receive the correct error\n" +
+                  s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause")
+              streamThreadDeathCause = null
+            }
+            ef.assertFailure(exception.getCause)
+          } catch {
+            case _: InterruptedException =>
+            case e: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
+              failTest("Timed out while waiting for failure", e)
+            case t: Throwable =>
+              failTest("Error while checking stream failure", t)
+          } finally {
+            lastStream = currentStream
+            currentStream = null
+          }
 
-          case ef: ExpectFailure[_] =>
-            verify(currentStream != null, "can not expect failure when stream is not running")
-            try failAfter(streamingTimeout) {
-              val thrownException = intercept[StreamingQueryException] {
-                currentStream.awaitTermination()
-              }
-              eventually("microbatch thread not stopped after termination with failure") {
-                assert(!currentStream.queryExecutionThread.isAlive)
+        case a: AssertOnQuery =>
+          verify(currentStream != null || lastStream != null,
+            "cannot assert when no stream has been started")
+          val streamToAssert = Option(currentStream).getOrElse(lastStream)
+          try {
+            verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}")
+          } catch {
+            case NonFatal(e) =>
+              failTest(s"Assert on query failed: ${a.message}", e)
+          }
+
+        case a: Assert =>
+          val streamToAssert = Option(currentStream).getOrElse(lastStream)
+          verify({ a.run(); true }, s"Assert failed: ${a.message}")
+
+        case a: AddData =>
+          try {
+
+            // If the query is running with manual clock, then wait for the stream execution
+            // thread to start waiting for the clock to increment. This is needed so that we
+            // are adding data when there is no trigger that is active. This would ensure that
+            // the data gets deterministically added to the next batch triggered after the manual
+            // clock is incremented in following AdvanceManualClock. This avoid race conditions
+            // between the test thread and the stream execution thread in tests using manual
+            // clock.
+            if (currentStream != null &&
+                currentStream.triggerClock.isInstanceOf[StreamManualClock]) {
+              val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock]
+              eventually("Error while synchronizing with manual clock before adding data") {
+                if (currentStream.isActive) {
+                  assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
+                }
               }
-              verify(currentStream.exception === Some(thrownException),
-                s"incorrect exception returned by query.exception()")
-
-              val exception = currentStream.exception.get
-              verify(exception.cause.getClass === ef.causeClass,
-                "incorrect cause in exception returned by query.exception()\n" +
-                  s"\tExpected: ${ef.causeClass}\n\tReturned: ${exception.cause.getClass}")
-              if (ef.isFatalError) {
-                // This is a fatal error, `streamThreadDeathCause` should be set to this error in
-                // UncaughtExceptionHandler.
-                verify(streamThreadDeathCause != null &&
-                  streamThreadDeathCause.getClass === ef.causeClass,
-                  "UncaughtExceptionHandler didn't receive the correct error\n" +
-                    s"\tExpected: ${ef.causeClass}\n\tReturned: $streamThreadDeathCause")
-                streamThreadDeathCause = null
+              if (!currentStream.isActive) {
+                failTest("Query terminated while synchronizing with manual clock")
               }
-              ef.assertFailure(exception.getCause)
-            } catch {
-              case _: InterruptedException =>
-              case e: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
-                failTest("Timed out while waiting for failure", e)
-              case t: Throwable =>
-                failTest("Error while checking stream failure", t)
-            } finally {
-              lastStream = currentStream
-              currentStream = null
             }
-
-          case a: AssertOnQuery =>
-            verify(currentStream != null || lastStream != null,
-              "cannot assert when no stream has been started")
-            val streamToAssert = Option(currentStream).getOrElse(lastStream)
-            try {
-              verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}")
-            } catch {
-              case NonFatal(e) =>
-                failTest(s"Assert on query failed: ${a.message}", e)
+            // Add data
+            val queryToUse = Option(currentStream).orElse(Option(lastStream))
+            val (source, offset) = a.addData(queryToUse)
+
+            def findSourceIndex(plan: LogicalPlan): Option[Int] = {
+              plan
+                .collect {
+                  case StreamingExecutionRelation(s, _) => s
+                  case StreamingDataSourceV2Relation(_, r) => r
+                }
+                .zipWithIndex
+                .find(_._1 == source)
+                .map(_._2)
             }
 
-          case a: Assert =>
-            val streamToAssert = Option(currentStream).getOrElse(lastStream)
-            verify({ a.run(); true }, s"Assert failed: ${a.message}")
-
-          case a: AddData =>
-            try {
-
-              // If the query is running with manual clock, then wait for the stream execution
-              // thread to start waiting for the clock to increment. This is needed so that we
-              // are adding data when there is no trigger that is active. This would ensure that
-              // the data gets deterministically added to the next batch triggered after the manual
-              // clock is incremented in following AdvanceManualClock. This avoid race conditions
-              // between the test thread and the stream execution thread in tests using manual
-              // clock.
-              if (currentStream != null &&
-                  currentStream.triggerClock.isInstanceOf[StreamManualClock]) {
-                val clock = currentStream.triggerClock.asInstanceOf[StreamManualClock]
-                eventually("Error while synchronizing with manual clock before adding data") {
-                  if (currentStream.isActive) {
-                    assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
-                  }
+            // Try to find the index of the source to which data was added. Either get the index
+            // from the current active query or the original input logical plan.
+            val sourceIndex =
+              queryToUse.flatMap { query =>
+                findSourceIndex(query.logicalPlan)
+              }.orElse {
+                findSourceIndex(stream.logicalPlan)
+              }.orElse {
+                queryToUse.flatMap { q =>
+                  findSourceIndex(q.lastExecution.logical)
                 }
-                if (!currentStream.isActive) {
-                  failTest("Query terminated while synchronizing with manual clock")
-                }
-              }
-              // Add data
-              val queryToUse = Option(currentStream).orElse(Option(lastStream))
-              val (source, offset) = a.addData(queryToUse)
-
-              def findSourceIndex(plan: LogicalPlan): Option[Int] = {
-                plan
-                  .collect {
-                    case StreamingExecutionRelation(s, _) => s
-                    case StreamingDataSourceV2Relation(_, r) => r
-                  }
-                  .zipWithIndex
-                  .find(_._1 == source)
-                  .map(_._2)
+              }.getOrElse {
+                throw new IllegalArgumentException(
+                  "Could not find index of the source to which data was added")
               }
 
-              // Try to find the index of the source to which data was added. Either get the index
-              // from the current active query or the original input logical plan.
-              val sourceIndex =
-                queryToUse.flatMap { query =>
-                  findSourceIndex(query.logicalPlan)
-                }.orElse {
-                  findSourceIndex(stream.logicalPlan)
-                }.orElse {
-                  queryToUse.flatMap { q =>
-                    findSourceIndex(q.lastExecution.logical)
-                  }
-                }.getOrElse {
-                  throw new IllegalArgumentException(
-                    "Could not find index of the source to which data was added")
-                }
+            // Store the expected offset of added data to wait for it later
+            awaiting.put(sourceIndex, offset)
+          } catch {
+            case NonFatal(e) =>
+              failTest("Error adding data", e)
+          }
 
-              // Store the expected offset of added data to wait for it later
-              awaiting.put(sourceIndex, offset)
-            } catch {
-              case NonFatal(e) =>
-                failTest("Error adding data", e)
-            }
+        case e: ExternalAction =>
+          e.runAction()
 
-          case e: ExternalAction =>
-            e.runAction()
+        case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) =>
+          val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
+          QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach {
+            error => failTest(error)
+          }
 
-          case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) =>
-            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
-            QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach {
-              error => failTest(error)
-            }
+        case CheckAnswerRowsContains(expectedAnswer, lastOnly) =>
+          val sparkAnswer = currentStream match {
+            case null => fetchStreamAnswer(lastStream, lastOnly)
+            case s => fetchStreamAnswer(s, lastOnly)
+          }
+          QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach {
+            error => failTest(error)
+          }
 
-          case CheckAnswerRowsContains(expectedAnswer, lastOnly) =>
-            val sparkAnswer = currentStream match {
-              case null => fetchStreamAnswer(lastStream, lastOnly)
-              case s => fetchStreamAnswer(s, lastOnly)
-            }
-            QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach {
-              error => failTest(error)
-            }
+        case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) =>
+          val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
+          try {
+            globalCheckFunction(sparkAnswer)
+          } catch {
+            case e: Throwable => failTest(e.toString)
+          }
+      }
+      pos += 1
+    }
 
-          case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) =>
-            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
-            try {
-              globalCheckFunction(sparkAnswer)
-            } catch {
-              case e: Throwable => failTest(e.toString)
-            }
-        }
-        pos += 1
+    try {
+      startedTest.foreach {
+        case StreamProgressLockedActions(actns, _) =>
+          // Perform actions while holding the stream from progressing
+          assert(currentStream != null,
+            s"Cannot perform stream-progress-locked actions $actns when query is not active")
+          assert(currentStream.isInstanceOf[MicroBatchExecution],
+            s"Cannot perform stream-progress-locked actions on non-microbatch queries")
+          currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked {
+            actns.foreach(executeAction)
+          }
+
+        case action: StreamAction => executeAction(action)
       }
       if (streamThreadDeathCause != null) {
         failTest("Stream Thread Died", streamThreadDeathCause)

http://git-wip-us.apache.org/repos/asf/spark/blob/855ce13d/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index 92087f6..11bdd13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -462,15 +462,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
         .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue)
 
     testStream(joined)(
-      AddData(leftInput, 1, 2, 3),
-      AddData(rightInput, 3, 4, 5),
+      MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5),
       // The left rows with leftValue <= 4 should generate their outer join row now and
       // not get added to the state.
       CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)),
       assertNumStateRows(total = 4, updated = 4),
       // We shouldn't get more outer join rows when the watermark advances.
-      AddData(leftInput, 20),
-      AddData(rightInput, 21),
+      MultiAddData(leftInput, 20)(rightInput, 21),
       CheckLastBatch(),
       AddData(rightInput, 20),
       CheckLastBatch((20, 30, 40, "60"))
@@ -493,15 +491,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
       .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue)
 
     testStream(joined)(
-      AddData(leftInput, 3, 4, 5),
-      AddData(rightInput, 1, 2, 3),
+      MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3),
       // The right rows with value <= 7 should never be added to the state.
       CheckLastBatch(Row(3, 10, 6, "9")),
       assertNumStateRows(total = 4, updated = 4),
       // When the watermark advances, we get the outer join rows just as we would if they
       // were added but didn't match the full join condition.
-      AddData(leftInput, 20),
-      AddData(rightInput, 21),
+      MultiAddData(leftInput, 20)(rightInput, 21),
       CheckLastBatch(),
       AddData(rightInput, 20),
       CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null))
@@ -524,15 +520,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
       .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue)
 
     testStream(joined)(
-      AddData(leftInput, 1, 2, 3),
-      AddData(rightInput, 3, 4, 5),
+      MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5),
       // The left rows with value <= 4 should never be added to the state.
       CheckLastBatch(Row(3, 10, 6, "9")),
       assertNumStateRows(total = 4, updated = 4),
       // When the watermark advances, we get the outer join rows just as we would if they
       // were added but didn't match the full join condition.
-      AddData(leftInput, 20),
-      AddData(rightInput, 21),
+      MultiAddData(leftInput, 20)(rightInput, 21),
       CheckLastBatch(),
       AddData(rightInput, 20),
       CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15"))
@@ -555,15 +549,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
       .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue)
 
     testStream(joined)(
-      AddData(leftInput, 3, 4, 5),
-      AddData(rightInput, 1, 2, 3),
+      MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3),
       // The right rows with rightValue <= 7 should generate their outer join row now and
       // not get added to the state.
       CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")),
       assertNumStateRows(total = 4, updated = 4),
       // We shouldn't get more outer join rows when the watermark advances.
-      AddData(leftInput, 20),
-      AddData(rightInput, 21),
+      MultiAddData(leftInput, 20)(rightInput, 21),
       CheckLastBatch(),
       AddData(rightInput, 20),
       CheckLastBatch((20, 30, 40, "60"))
@@ -575,13 +567,11 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
 
     testStream(joined)(
       // Test inner part of the join.
-      AddData(leftInput, 1, 2, 3, 4, 5),
-      AddData(rightInput, 3, 4, 5, 6, 7),
+      MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7),
       CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
       // Old state doesn't get dropped until the batch *after* it gets introduced, so the
       // nulls won't show up until the next batch after the watermark advances.
-      AddData(leftInput, 21),
-      AddData(rightInput, 22),
+      MultiAddData(leftInput, 21)(rightInput, 22),
       CheckLastBatch(),
       assertNumStateRows(total = 12, updated = 2),
       AddData(leftInput, 22),
@@ -595,13 +585,11 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
 
     testStream(joined)(
       // Test inner part of the join.
-      AddData(leftInput, 1, 2, 3, 4, 5),
-      AddData(rightInput, 3, 4, 5, 6, 7),
+      MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7),
       CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
       // Old state doesn't get dropped until the batch *after* it gets introduced, so the
       // nulls won't show up until the next batch after the watermark advances.
-      AddData(leftInput, 21),
-      AddData(rightInput, 22),
+      MultiAddData(leftInput, 21)(rightInput, 22),
       CheckLastBatch(),
       assertNumStateRows(total = 12, updated = 2),
       AddData(leftInput, 22),
@@ -676,11 +664,9 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
 
     testStream(joined)(
       // leftValue <= 10 should generate outer join rows even though it matches right keys
-      AddData(leftInput, 1, 2, 3),
-      AddData(rightInput, 1, 2, 3),
+      MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3),
       CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)),
-      AddData(leftInput, 20),
-      AddData(rightInput, 21),
+      MultiAddData(leftInput, 20)(rightInput, 21),
       CheckLastBatch(),
       assertNumStateRows(total = 5, updated = 2),
       AddData(rightInput, 20),
@@ -688,22 +674,18 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
         Row(20, 30, 40, 60)),
       assertNumStateRows(total = 3, updated = 1),
       // leftValue and rightValue both satisfying condition should not generate outer join rows
-      AddData(leftInput, 40, 41),
-      AddData(rightInput, 40, 41),
+      MultiAddData(leftInput, 40, 41)(rightInput, 40, 41),
       CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)),
-      AddData(leftInput, 70),
-      AddData(rightInput, 71),
+      MultiAddData(leftInput, 70)(rightInput, 71),
       CheckLastBatch(),
       assertNumStateRows(total = 6, updated = 2),
       AddData(rightInput, 70),
       CheckLastBatch((70, 80, 140, 210)),
       assertNumStateRows(total = 3, updated = 1),
       // rightValue between 300 and 1000 should generate outer join rows even though it matches left
-      AddData(leftInput, 101, 102, 103),
-      AddData(rightInput, 101, 102, 103),
+      MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103),
       CheckLastBatch(),
-      AddData(leftInput, 1000),
-      AddData(rightInput, 1001),
+      MultiAddData(leftInput, 1000)(rightInput, 1001),
       CheckLastBatch(),
       assertNumStateRows(total = 8, updated = 2),
       AddData(rightInput, 1000),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org