Posted to commits@spark.apache.org by ka...@apache.org on 2024/02/21 03:59:02 UTC

(spark) branch master updated: [SPARK-47052][SS] Separate state tracking variables from MicroBatchExecution/StreamExecution

This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new bffa92c838d6 [SPARK-47052][SS] Separate state tracking variables from MicroBatchExecution/StreamExecution
bffa92c838d6 is described below

commit bffa92c838d6650249a6e71bb0ef8189cf970383
Author: Jerry Peng <je...@databricks.com>
AuthorDate: Wed Feb 21 12:58:48 2024 +0900

    [SPARK-47052][SS] Separate state tracking variables from MicroBatchExecution/StreamExecution
    
    ### What changes were proposed in this pull request?
    
    To improve code clarity and maintainability, I propose that we move all the variables that track mutable state and metrics for a streaming query into a separate class.  With this refactor, it becomes easier to track and find all the mutable state a microbatch can have.
    
    ### Why are the changes needed?
    
    To improve code clarity and maintainability.  All the state and metrics that are needed for the execution lifecycle of a microbatch are consolidated into one class.  If we decide to modify or add state to a streaming query, it will be easier to determine 1) where to add it and 2) what state already exists.
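
    For illustration, a minimal sketch of the pattern (the names `BatchContext` and `QueryDriver` below are invented for this example and are not the actual Spark classes): per-batch mutable state lives in a context object created for each micro-batch, while the long-lived execution class only keeps a pointer to the latest context.

    ```scala
    // Sketch only: invented names, not the Spark implementation.
    final case class BatchContext(
        batchId: Long,
        startOffsets: Map[String, Long],
        endOffsets: Map[String, Long],
        isCurrentBatchConstructed: Boolean) {

      // Derive the context for the next batch, carrying end offsets forward.
      def nextContext(): BatchContext =
        copy(batchId = batchId + 1, startOffsets = endOffsets, isCurrentBatchConstructed = false)
    }

    final class QueryDriver {
      @volatile private var latest =
        BatchContext(-1L, Map.empty, Map.empty, isCurrentBatchConstructed = false)

      // Never move the pointer back to an older batch.
      def setLatest(ctx: BatchContext): Unit = synchronized {
        if (ctx.batchId >= latest.batchId) latest = ctx
      }

      def latestContext: BatchContext = latest
    }
    ```

    The new `StreamExecutionContext` / `MicroBatchExecutionContext` classes in this change apply the same idea at full scale: each batch reads and mutates only its own context, and `MicroBatchExecution` only tracks the latest one.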
    
    ### Does this PR introduce _any_ user-facing change?
    
    No

    ### How was this patch tested?
    
    Existing tests should suffice
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #45109 from jerrypeng/SPARK-47052.
    
    Authored-by: Jerry Peng <je...@databricks.com>
    Signed-off-by: Jungtaek Lim <ka...@gmail.com>
---
 .../sql/execution/streaming/AsyncLogPurge.scala    |  11 +-
 .../AsyncProgressTrackingMicroBatchExecution.scala |  30 +-
 .../execution/streaming/MicroBatchExecution.scala  | 422 ++++++++++-------
 .../sql/execution/streaming/ProgressReporter.scala | 521 +++++++++++++--------
 .../sql/execution/streaming/StreamExecution.scala  | 112 +++--
 .../streaming/StreamExecutionContext.scala         | 233 +++++++++
 .../sql/execution/streaming/TriggerExecutor.scala  |  24 +-
 .../streaming/continuous/ContinuousExecution.scala |  56 ++-
 .../streaming/ProcessingTimeExecutorSuite.scala    |   6 +-
 9 files changed, 945 insertions(+), 470 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
index b3729dbc7b45..aa393211a1c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
@@ -29,11 +29,8 @@ import org.apache.spark.util.ThreadUtils
  */
 trait AsyncLogPurge extends Logging {
 
-  protected var currentBatchId: Long
-
   protected val minLogEntriesToMaintain: Int
 
-
   protected[sql] val errorNotifier: ErrorNotifier
 
   protected val sparkSession: SparkSession
@@ -47,15 +44,11 @@ trait AsyncLogPurge extends Logging {
 
   protected lazy val useAsyncPurge: Boolean = sparkSession.conf.get(SQLConf.ASYNC_LOG_PURGE)
 
-  protected def purgeAsync(): Unit = {
+  protected def purgeAsync(batchId: Long): Unit = {
     if (purgeRunning.compareAndSet(false, true)) {
-      // save local copy because currentBatchId may get updated.  There are not really
-      // any concurrency issues here in regards to calculating the purge threshold
-      // but for the sake of defensive coding lets make a copy
-      val currentBatchIdCopy: Long = currentBatchId
       asyncPurgeExecutorService.execute(() => {
         try {
-          purge(currentBatchIdCopy - minLogEntriesToMaintain)
+          purge(batchId - minLogEntriesToMaintain)
         } catch {
           case throwable: Throwable =>
             logError("Encountered error while performing async log purge", throwable)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala
index 206efb9a5450..ec24ec0fd335 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala
@@ -110,12 +110,12 @@ class AsyncProgressTrackingMicroBatchExecution(
     }
   }
 
-  override def markMicroBatchExecutionStart(): Unit = {
+  override def markMicroBatchExecutionStart(execCtx: MicroBatchExecutionContext): Unit = {
     // check if streaming query is stateful
     checkNotStatefulStreamingQuery
   }
 
-  override def cleanUpLastExecutedMicroBatch(): Unit = {
+  override def cleanUpLastExecutedMicroBatch(execCtx: MicroBatchExecutionContext): Unit = {
     // this is a no op for async progress tracking since we only want to commit sources only
     // after the offset WAL commit has be successfully written
   }
@@ -124,11 +124,11 @@ class AsyncProgressTrackingMicroBatchExecution(
    * Should not call super method as we need to do something completely different
    * in this method for async progress tracking
    */
-  override def markMicroBatchStart(): Unit = {
+  override def markMicroBatchStart(execCtx: MicroBatchExecutionContext): Unit = {
     // Because we are using a thread pool with only one thread, async writes to the offset log
     // are still written in a serial / in order fashion
     offsetLog
-      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .addAsync(execCtx.batchId, execCtx.endOffsets.toOffsetSeq(sources, execCtx.offsetSeqMetadata))
       .thenAccept(tuple => {
         val (batchId, persistedToDurableStorage) = tuple
         if (persistedToDurableStorage) {
@@ -157,7 +157,7 @@ class AsyncProgressTrackingMicroBatchExecution(
       })
       .exceptionally((th: Throwable) => {
         logError(s"Encountered error while performing" +
-          s" async offset write for batch ${currentBatchId}", th)
+          s" async offset write for batch ${execCtx.batchId}", th)
         errorNotifier.markError(th)
         return
       })
@@ -170,9 +170,9 @@ class AsyncProgressTrackingMicroBatchExecution(
     }
   }
 
-  override def markMicroBatchEnd(): Unit = {
-    watermarkTracker.updateWatermark(lastExecution.executedPlan)
-    reportTimeTaken("commitOffsets") {
+  override def markMicroBatchEnd(execCtx: MicroBatchExecutionContext): Unit = {
+    watermarkTracker.updateWatermark(execCtx.executionPlan.executedPlan)
+    execCtx.reportTimeTaken("commitOffsets") {
       // check if current batch there is a async write for the offset log is issued for this batch
       // if so, we should do the same for commit log.  However, if this is the first batch executed
       // in this run we should always persist to the commit log.  There can be situations in which
@@ -181,27 +181,27 @@ class AsyncProgressTrackingMicroBatchExecution(
       // and the commit log is 0, 2.  On restart we will re-process the data from batch 3 -> 5.
       // Batch 5 is already part of the offset log but we still need to write the entry to
       // the commit log
-      if (offsetLog.getAsyncOffsetWrite(currentBatchId).nonEmpty
+      if (offsetLog.getAsyncOffsetWrite(execCtx.batchId).nonEmpty
         || isFirstBatch) {
         isFirstBatch = false
 
         commitLog
-          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .addAsync(execCtx.batchId, CommitMetadata(watermarkTracker.currentWatermark))
           .exceptionally((th: Throwable) => {
             logError(s"Got exception during async write to commit log" +
-              s" for batch ${currentBatchId}", th)
+              s" for batch ${execCtx.batchId}", th)
             errorNotifier.markError(th)
             return
           })
       } else {
         if (!commitLog.addInMemory(
-          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
-          throw QueryExecutionErrors.concurrentStreamLogUpdate(currentBatchId)
+          execCtx.batchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
         }
       }
-      offsetLog.removeAsyncOffsetWrite(currentBatchId)
+      offsetLog.removeAsyncOffsetWrite(execCtx.batchId)
     }
-    committedOffsets ++= availableOffsets
+    committedOffsets ++= execCtx.endOffsets
   }
 
   // need to look at the number of files on disk
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 8c98ad5c47dd..ae5a033538ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -48,6 +48,39 @@ class MicroBatchExecution(
     sparkSession, plan.name, plan.resolvedCheckpointLocation, plan.inputQuery, plan.sink, trigger,
     triggerClock, plan.outputMode, plan.deleteCheckpointOnStop) with AsyncLogPurge {
 
+  /**
+   * Keeps track of the latest execution context
+   */
+  @volatile private var latestExecutionContext: StreamExecutionContext =
+    new MicroBatchExecutionContext(
+      id,
+      runId,
+      name,
+      triggerClock,
+      Seq.empty,
+      sink,
+      progressReporter,
+      -1,
+      sparkSession,
+      previousContext = None)
+
+  override def getLatestExecutionContext(): StreamExecutionContext = latestExecutionContext
+
+  /**
+   * Sets latestExecutionContext only if the batch id of the given context is at least as
+   * large as the batch id of the current latestExecutionContext.  This ensures we always
+   * track the latest execution context, i.e. latestExecutionContext is never set to an
+   * earlier / older batch.
+   * @param ctx the execution context to record as the latest
+   */
+  def setLatestExecutionContext(ctx: StreamExecutionContext): Unit = synchronized {
+    // make sure we are setting to the latest batch
+    if (latestExecutionContext.batchId <= ctx.batchId) {
+      latestExecutionContext = ctx
+    }
+  }
+
+
   protected[sql] val errorNotifier = new ErrorNotifier()
 
   @volatile protected var sources: Seq[SparkDataStream] = Seq.empty
@@ -231,12 +264,6 @@ class MicroBatchExecution(
     }
   }
 
-  /**
-   * Signifies whether current batch (i.e. for the batch `currentBatchId`) has been constructed
-   * (i.e. written to the offsetLog) and is ready for execution.
-   */
-  private var isCurrentBatchConstructed = false
-
   /**
    * Signals to the thread executing micro-batches that it should stop running after the next
    * batch. This method blocks until the thread stops running.
@@ -264,94 +291,115 @@ class MicroBatchExecution(
     logInfo(s"Async log purge executor pool for query ${prettyIdString} has been shutdown")
   }
 
-  /** Begins recording statistics about query progress for a given trigger. */
-  override protected def startTrigger(): Unit = {
-    super.startTrigger()
-    currentStatus = currentStatus.copy(isTriggerActive = true)
-  }
+  private def initializeExecution(
+      sparkSessionForStream: SparkSession): MicroBatchExecutionContext = {
+    AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources(
+      offsetLog.getLatest().map(_._2), sources)
+
+    val execCtx = new MicroBatchExecutionContext(id, runId, name, triggerClock, sources, sink,
+      progressReporter, -1, sparkSession, None)
 
+    execCtx.offsetSeqMetadata =
+      OffsetSeqMetadata(batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionForStream.conf)
+    setLatestExecutionContext(execCtx)
+
+    populateStartOffsets(execCtx, sparkSessionForStream)
+    logInfo(s"Stream started from ${execCtx.startOffsets}")
+    execCtx
+  }
   /**
    * Repeatedly attempts to run batches as data arrives.
    */
   protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {
 
+    // create the first batch to run
+    val execCtx = initializeExecution(sparkSessionForStream)
+    triggerExecutor.setNextBatch(execCtx)
+
     val noDataBatchesEnabled =
       sparkSessionForStream.sessionState.conf.streamingNoDataMicroBatchesEnabled
 
-    triggerExecutor.execute(() => {
-      if (isActive) {
+    triggerExecutor.execute(executeOneBatch(_, sparkSessionForStream, noDataBatchesEnabled))
+  }
 
-        // check if there are any previous errors and bubble up any existing async operations
-        errorNotifier.throwErrorIfExists()
+  private def executeOneBatch(
+      execCtx: MicroBatchExecutionContext,
+      sparkSessionForStream: SparkSession,
+      noDataBatchesEnabled: Boolean): Boolean = {
+    assert(execCtx != null)
 
-        var currentBatchHasNewData = false // Whether the current batch had new data
+    if (isActive) {
+      logDebug(s"Running batch with context: ${execCtx}")
+      setLatestExecutionContext(execCtx)
 
-        startTrigger()
+      // check if there are any previous errors and bubble up any existing async operations
+      errorNotifier.throwErrorIfExists()
 
-        reportTimeTaken("triggerExecution") {
-          // We'll do this initialization only once every start / restart
-          if (currentBatchId < 0) {
-            AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources(
-              offsetLog.getLatest().map(_._2), sources)
-            populateStartOffsets(sparkSessionForStream)
-            logInfo(s"Stream started from $committedOffsets")
-          }
+      var currentBatchHasNewData = false // Whether the current batch had new data
 
-          // Set this before calling constructNextBatch() so any Spark jobs executed by sources
-          // while getting new data have the correct description
-          sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
-
-          // Try to construct the next batch. This will return true only if the next batch is
-          // ready and runnable. Note that the current batch may be runnable even without
-          // new data to process as `constructNextBatch` may decide to run a batch for
-          // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data
-          // is available or not.
-          if (!isCurrentBatchConstructed) {
-            isCurrentBatchConstructed = constructNextBatch(noDataBatchesEnabled)
-          }
+      execCtx.startTrigger()
 
-          // Record the trigger offset range for progress reporting *before* processing the batch
-          recordTriggerOffsets(
-            from = committedOffsets,
-            to = availableOffsets,
-            latest = latestOffsets)
-
-          // Remember whether the current batch has data or not. This will be required later
-          // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed
-          // to false as the batch would have already processed the available data.
-          currentBatchHasNewData = isNewDataAvailable
-
-          currentStatus = currentStatus.copy(isDataAvailable = isNewDataAvailable)
-          if (isCurrentBatchConstructed) {
-            if (currentBatchHasNewData) updateStatusMessage("Processing new data")
-            else updateStatusMessage("No new data but cleaning up state")
-            runBatch(sparkSessionForStream)
-          } else {
-            updateStatusMessage("Waiting for data to arrive")
-          }
+      execCtx.reportTimeTaken("triggerExecution") {
+        // Set this before calling constructNextBatch() so any Spark jobs executed by sources
+        // while getting new data have the correct description
+        sparkSession.sparkContext.setJobDescription(getBatchDescriptionString)
+
+        // Try to construct the next batch. This will return true only if the next batch is
+        // ready and runnable. Note that the current batch may be runnable even without
+        // new data to process as `constructNextBatch` may decide to run a batch for
+        // state cleanup, etc. `isNewDataAvailable` will be updated to reflect whether new data
+        // is available or not.
+        if (!execCtx.isCurrentBatchConstructed) {
+          execCtx.isCurrentBatchConstructed = constructNextBatch(execCtx, noDataBatchesEnabled)
         }
 
-        // Must be outside reportTimeTaken so it is recorded
-        finishTrigger(currentBatchHasNewData, isCurrentBatchConstructed)
-
-        // Signal waiting threads. Note this must be after finishTrigger() to ensure all
-        // activities (progress generation, etc.) have completed before signaling.
-        withProgressLocked { awaitProgressLockCondition.signalAll() }
-
-        // If the current batch has been executed, then increment the batch id and reset flag.
-        // Otherwise, there was no data to execute the batch and sleep for some time
-        if (isCurrentBatchConstructed) {
-          currentBatchId += 1
-          isCurrentBatchConstructed = false
-        } else if (triggerExecutor.isInstanceOf[MultiBatchExecutor]) {
-          logInfo("Finished processing all available data for the trigger, terminating this " +
-            "Trigger.AvailableNow query")
-          state.set(TERMINATED)
-        } else Thread.sleep(pollingDelayMs)
+        // Record the trigger offset range for progress reporting *before* processing the batch
+        execCtx.recordTriggerOffsets(
+          from = execCtx.startOffsets,
+          to = execCtx.endOffsets,
+          latest = execCtx.latestOffsets)
+
+        // Remember whether the current batch has data or not. This will be required later
+        // for bookkeeping after running the batch, when `isNewDataAvailable` will have changed
+        // to false as the batch would have already processed the available data.
+        currentBatchHasNewData = isNewDataAvailable(execCtx)
+
+        execCtx.currentStatus
+          = execCtx.currentStatus.copy(isDataAvailable = isNewDataAvailable(execCtx))
+        if (execCtx.isCurrentBatchConstructed) {
+          if (currentBatchHasNewData) execCtx.updateStatusMessage("Processing new data")
+          else execCtx.updateStatusMessage("No new data but cleaning up state")
+          runBatch(execCtx, sparkSessionForStream)
+        } else {
+          execCtx.updateStatusMessage("Waiting for data to arrive")
+        }
+      }
+
+      execCtx.carryOverExecStatsOnLatestExecutedBatch()
+      // Must be outside reportTimeTaken so it is recorded
+      if (execCtx.isCurrentBatchConstructed) {
+        execCtx.finishTrigger(currentBatchHasNewData, execCtx.executionPlan, execCtx.batchId)
+      } else {
+        execCtx.finishNoExecutionTrigger(execCtx.batchId)
       }
-      updateStatusMessage("Waiting for next trigger")
-      isActive
-    })
+
+      // Signal waiting threads. Note this must be after finishTrigger() to ensure all
+      // activities (progress generation, etc.) have completed before signaling.
+      withProgressLocked { awaitProgressLockCondition.signalAll() }
+
+      // If the current batch has been executed, then increment the batch id and reset flag.
+      // Otherwise, there was no data to execute the batch and sleep for some time
+      if (execCtx.isCurrentBatchConstructed) {
+        triggerExecutor.setNextBatch(execCtx.getNextContext())
+        execCtx.onExecutionComplete()
+      } else if (triggerExecutor.isInstanceOf[MultiBatchExecutor]) {
+        logInfo("Finished processing all available data for the trigger, terminating this " +
+          "Trigger.AvailableNow query")
+        state.set(TERMINATED)
+      } else Thread.sleep(pollingDelayMs)
+    }
+    execCtx.updateStatusMessage("Waiting for next trigger")
+    isActive
   }
 
   /**
@@ -379,10 +427,11 @@ class MicroBatchExecution(
   /**
    * Populate the start offsets to start the execution at the current offsets stored in the sink
    * (i.e. avoid reprocessing data that we have already processed). This function must be called
-   * before any processing occurs and will populate the following fields:
-   *  - currentBatchId
-   *  - committedOffsets
-   *  - availableOffsets
+   * before any processing occurs and will populate the following fields in the execution context
+   * of this micro-batch:
+   *  - batchId
+   *  - startOffsets
+   *  - endOffsets
    *  The basic structure of this method is as follows:
    *
    *  Identify (from the offset log) the offsets used to run the last batch
@@ -398,24 +447,28 @@ class MicroBatchExecution(
    *    Identify a brand new batch
    *  DONE
    */
-  private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = {
-    sinkCommitProgress = None
+  protected def populateStartOffsets(
+      execCtx: MicroBatchExecutionContext,
+      sparkSessionToRunBatches: SparkSession): Unit = {
+    execCtx.sinkCommitProgress = None
     offsetLog.getLatest() match {
       case Some((latestBatchId, nextOffsets)) =>
         /* First assume that we are re-executing the latest known batch
          * in the offset log */
-        currentBatchId = latestBatchId
-        isCurrentBatchConstructed = true
-        availableOffsets = nextOffsets.toStreamProgress(sources)
+        execCtx.batchId = latestBatchId
+        execCtx.isCurrentBatchConstructed = true
+        execCtx.endOffsets = nextOffsets.toStreamProgress(sources)
 
         // validate the integrity of offset log and get the previous offset from the offset log
         val secondLatestOffsets = validateOffsetLogAndGetPrevOffset(latestBatchId)
-        secondLatestOffsets.foreach(offset => committedOffsets = offset.toStreamProgress(sources))
+        secondLatestOffsets.foreach { offset =>
+          execCtx.startOffsets = offset.toStreamProgress(sources)
+        }
 
         // update offset metadata
         nextOffsets.metadata.foreach { metadata =>
           OffsetSeqMetadata.setSessionConf(metadata, sparkSessionToRunBatches.conf)
-          offsetSeqMetadata = OffsetSeqMetadata(
+          execCtx.offsetSeqMetadata = OffsetSeqMetadata(
             metadata.batchWatermarkMs, metadata.batchTimestampMs, sparkSessionToRunBatches.conf)
           watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf)
           watermarkTracker.setWatermark(metadata.batchWatermarkMs)
@@ -432,23 +485,23 @@ class MicroBatchExecution(
                * Make a call to getBatch using the offsets from previous batch.
                * because certain sources (e.g., KafkaSource) assume on restart the last
                * batch will be executed before getOffset is called again. */
-              availableOffsets.foreach {
+              execCtx.endOffsets.foreach {
                 case (source: Source, end: Offset) =>
-                  val start = committedOffsets.get(source).map(_.asInstanceOf[Offset])
+                  val start = execCtx.startOffsets.get(source).map(_.asInstanceOf[Offset])
                   source.getBatch(start, end)
                 case nonV1Tuple =>
                   // The V2 API does not have the same edge case requiring getBatch to be called
                   // here, so we do nothing here.
               }
-              currentBatchId = latestCommittedBatchId + 1
-              isCurrentBatchConstructed = false
-              committedOffsets ++= availableOffsets
+              execCtx.batchId = latestCommittedBatchId + 1
+              execCtx.isCurrentBatchConstructed = false
+              execCtx.startOffsets ++= execCtx.endOffsets
               watermarkTracker.setWatermark(
                 math.max(watermarkTracker.currentWatermark, commitMetadata.nextBatchWatermarkMs))
             } else if (latestCommittedBatchId == latestBatchId - 1) {
-              availableOffsets.foreach {
+              execCtx.endOffsets.foreach {
                 case (source: Source, end: Offset) =>
-                  val start = committedOffsets.get(source).map(_.asInstanceOf[Offset])
+                  val start = execCtx.startOffsets.get(source).map(_.asInstanceOf[Offset])
                   if (start.map(_ == end).getOrElse(true)) {
                     source.getBatch(start, end)
                   }
@@ -463,11 +516,13 @@ class MicroBatchExecution(
             }
           case None => logInfo("no commit log present")
         }
-        logInfo(s"Resuming at batch $currentBatchId with committed offsets " +
-          s"$committedOffsets and available offsets $availableOffsets")
+        // initialize committed offsets to start offsets of the most recent committed batch
+        committedOffsets = execCtx.startOffsets
+        logInfo(s"Resuming at batch ${execCtx.batchId} with committed offsets " +
+          s"${execCtx.startOffsets} and available offsets ${execCtx.endOffsets}")
       case None => // We are starting this stream for the first time.
         logInfo(s"Starting new streaming query.")
-        currentBatchId = 0
+        execCtx.batchId = 0
         watermarkTracker = WatermarkTracker(sparkSessionToRunBatches.conf)
     }
   }
@@ -475,10 +530,10 @@ class MicroBatchExecution(
   /**
    * Returns true if there is any new data available to be processed.
    */
-  private def isNewDataAvailable: Boolean = {
-    availableOffsets.exists {
+  private def isNewDataAvailable(execCtx: MicroBatchExecutionContext): Boolean = {
+    execCtx.endOffsets.exists {
       case (source, available) =>
-        committedOffsets
+        execCtx.startOffsets
           .get(source)
           .map(committed => committed != available)
           .getOrElse(true)
@@ -486,11 +541,13 @@ class MicroBatchExecution(
   }
 
   /**
-   * Get the startOffset from availableOffsets. This is to be used in
+   * Get the startOffset from endOffsets. This is to be used in
    * latestOffset(startOffset, readLimit)
    */
-  private def getStartOffset(dataStream: SparkDataStream): OffsetV2 = {
-    val startOffsetOpt = availableOffsets.get(dataStream)
+  private def getStartOffset(
+      execCtx: MicroBatchExecutionContext,
+      dataStream: SparkDataStream): OffsetV2 = {
+    val startOffsetOpt = execCtx.startOffsets.get(dataStream)
     dataStream match {
       case _: Source =>
         startOffsetOpt.orNull
@@ -514,35 +571,37 @@ class MicroBatchExecution(
    * - If either of the above is true, then construct the next batch by committing to the offset
    *   log that range of offsets that the next batch will process.
    */
-  private def constructNextBatch(noDataBatchesEnabled: Boolean): Boolean = withProgressLocked {
-    if (isCurrentBatchConstructed) return true
+  private def constructNextBatch(
+      execCtx: MicroBatchExecutionContext,
+      noDataBatchesEnabled: Boolean): Boolean = withProgressLocked {
+    if (execCtx.isCurrentBatchConstructed) return true
 
     // Generate a map from each unique source to the next available offset.
     val (nextOffsets, recentOffsets) = uniqueSources.toSeq.map {
       case (s: AvailableNowDataStreamWrapper, limit) =>
-        updateStatusMessage(s"Getting offsets from $s")
+        execCtx.updateStatusMessage(s"Getting offsets from $s")
         val originalSource = s.delegate
-        reportTimeTaken("latestOffset") {
-          val next = s.latestOffset(getStartOffset(originalSource), limit)
+        execCtx.reportTimeTaken("latestOffset") {
+          val next = s.latestOffset(getStartOffset(execCtx, originalSource), limit)
           val latest = s.reportLatestOffset()
           ((originalSource, Option(next)), (originalSource, Option(latest)))
         }
       case (s: SupportsAdmissionControl, limit) =>
-        updateStatusMessage(s"Getting offsets from $s")
-        reportTimeTaken("latestOffset") {
-          val next = s.latestOffset(getStartOffset(s), limit)
+        execCtx.updateStatusMessage(s"Getting offsets from $s")
+        execCtx.reportTimeTaken("latestOffset") {
+          val next = s.latestOffset(getStartOffset(execCtx, s), limit)
           val latest = s.reportLatestOffset()
           ((s, Option(next)), (s, Option(latest)))
         }
       case (s: Source, _) =>
-        updateStatusMessage(s"Getting offsets from $s")
-        reportTimeTaken("getOffset") {
+        execCtx.updateStatusMessage(s"Getting offsets from $s")
+        execCtx.reportTimeTaken("getOffset") {
           val offset = s.getOffset
           ((s, offset), (s, offset))
         }
       case (s: MicroBatchStream, _) =>
-        updateStatusMessage(s"Getting offsets from $s")
-        reportTimeTaken("latestOffset") {
+        execCtx.updateStatusMessage(s"Getting offsets from $s")
+        execCtx.reportTimeTaken("latestOffset") {
           val latest = s.latestOffset()
           ((s, Option(latest)), (s, Option(latest)))
         }
@@ -551,31 +610,34 @@ class MicroBatchExecution(
         throw new IllegalStateException(s"Unexpected source: $s")
     }.unzip
 
-    availableOffsets ++= nextOffsets.filter { case (_, o) => o.nonEmpty }
+    execCtx.endOffsets ++= nextOffsets.filter { case (_, o) => o.nonEmpty }
       .map(p => p._1 -> p._2.get).toMap
-    latestOffsets ++= recentOffsets.filter { case (_, o) => o.nonEmpty }
+    execCtx.latestOffsets ++= recentOffsets.filter { case (_, o) => o.nonEmpty }
       .map(p => p._1 -> p._2.get).toMap
 
     // Update the query metadata
-    offsetSeqMetadata = offsetSeqMetadata.copy(
+    execCtx.offsetSeqMetadata = execCtx.offsetSeqMetadata.copy(
       batchWatermarkMs = watermarkTracker.currentWatermark,
       batchTimestampMs = triggerClock.getTimeMillis())
 
     // Check whether next batch should be constructed
     val lastExecutionRequiresAnotherBatch = noDataBatchesEnabled &&
-      Option(lastExecution).exists(_.shouldRunAnotherBatch(offsetSeqMetadata))
-    val shouldConstructNextBatch = isNewDataAvailable || lastExecutionRequiresAnotherBatch
+      // need to check the execution plan of the previous batch
+      execCtx.previousContext.map { plan =>
+        Option(plan.executionPlan).exists(_.shouldRunAnotherBatch(execCtx.offsetSeqMetadata))
+      }.getOrElse(false)
+    val shouldConstructNextBatch = isNewDataAvailable(execCtx) || lastExecutionRequiresAnotherBatch
     logTrace(
       s"noDataBatchesEnabled = $noDataBatchesEnabled, " +
       s"lastExecutionRequiresAnotherBatch = $lastExecutionRequiresAnotherBatch, " +
-      s"isNewDataAvailable = $isNewDataAvailable, " +
+      s"isNewDataAvailable = ${isNewDataAvailable(execCtx)}, " +
       s"shouldConstructNextBatch = $shouldConstructNextBatch")
 
     if (shouldConstructNextBatch) {
       // Commit the next batch offset range to the offset log
-      updateStatusMessage("Writing offsets to log")
-      reportTimeTaken("walCommit") {
-        markMicroBatchStart()
+      execCtx.updateStatusMessage("Writing offsets to log")
+      execCtx.reportTimeTaken("walCommit") {
+        markMicroBatchStart(execCtx)
 
         // NOTE: The following code is correct because runStream() processes exactly one
         // batch at a time. If we add pipeline parallelism (multiple batches in flight at
@@ -583,15 +645,15 @@ class MicroBatchExecution(
 
         // Now that we've updated the scheduler's persistent checkpoint, it is safe for the
         // sources to discard data from the previous batch.
-        cleanUpLastExecutedMicroBatch()
+        cleanUpLastExecutedMicroBatch(execCtx)
 
         // It is now safe to discard the metadata beyond the minimum number to retain.
         // Note that purge is exclusive, i.e. it purges everything before the target ID.
-        if (minLogEntriesToMaintain < currentBatchId) {
+        if (minLogEntriesToMaintain < execCtx.batchId) {
           if (useAsyncPurge) {
-            purgeAsync()
+            purgeAsync(execCtx.batchId)
           } else {
-            purge(currentBatchId - minLogEntriesToMaintain)
+            purge(execCtx.batchId - minLogEntriesToMaintain)
           }
         }
       }
@@ -615,18 +677,20 @@ class MicroBatchExecution(
   }
 
   /**
-   * Processes any data available between `availableOffsets` and `committedOffsets`.
+   * Processes any data available between `endOffsets` and `startOffsets`.
    * @param sparkSessionToRunBatch Isolated [[SparkSession]] to run this batch with.
    */
-  private def runBatch(sparkSessionToRunBatch: SparkSession): Unit = {
-    logDebug(s"Running batch $currentBatchId")
+  private def runBatch(
+      execCtx: MicroBatchExecutionContext,
+      sparkSessionToRunBatch: SparkSession): Unit = {
+    logDebug(s"Running batch ${execCtx.batchId}")
 
     // Request unprocessed data from all sources.
-    val mutableNewData = mutable.Map.empty ++ reportTimeTaken("getBatch") {
-      availableOffsets.flatMap {
+    val mutableNewData = mutable.Map.empty ++ execCtx.reportTimeTaken("getBatch") {
+      execCtx.endOffsets.flatMap {
         case (source: Source, available: Offset)
-          if committedOffsets.get(source).map(_ != available).getOrElse(true) =>
-          val current = committedOffsets.get(source).map(_.asInstanceOf[Offset])
+          if execCtx.startOffsets.get(source).map(_ != available).getOrElse(true) =>
+          val current = execCtx.startOffsets.get(source).map(_.asInstanceOf[Offset])
           val batch = source.getBatch(current, available)
           assert(batch.isStreaming,
             s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" +
@@ -635,8 +699,8 @@ class MicroBatchExecution(
           Some(source -> batch.logicalPlan)
 
         case (stream: MicroBatchStream, available)
-          if committedOffsets.get(stream).map(_ != available).getOrElse(true) =>
-          val current = committedOffsets.get(stream).map {
+          if execCtx.startOffsets.get(stream).map(_ != available).getOrElse(true) =>
+          val current = execCtx.startOffsets.get(stream).map {
             off => stream.deserializeOffset(off.json)
           }
           val endOffset: OffsetV2 = available match {
@@ -716,7 +780,7 @@ class MicroBatchExecution(
           LocalRelation(r.output, isStreaming = true)
         }
     }
-    newData = mutableNewData.toMap
+    execCtx.newData = mutableNewData.toMap
     // Rewire the plan to use the new attributes that were returned by the source.
     val newAttributePlan = newBatchesPlan.transformAllExpressionsWithPruning(
       _.containsPattern(CURRENT_LIKE)) {
@@ -724,56 +788,56 @@ class MicroBatchExecution(
         // CurrentTimestamp is not TimeZoneAwareExpression while CurrentBatchTimestamp is.
         // Without TimeZoneId, CurrentBatchTimestamp is unresolved. Here, we use an explicit
         // dummy string to prevent UnresolvedException and to prevent to be used in the future.
-        CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
+        CurrentBatchTimestamp(execCtx.offsetSeqMetadata.batchTimestampMs,
           ct.dataType, Some("Dummy TimeZoneId"))
       case lt: LocalTimestamp =>
-        CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
+        CurrentBatchTimestamp(execCtx.offsetSeqMetadata.batchTimestampMs,
           lt.dataType, lt.timeZoneId)
       case cd: CurrentDate =>
-        CurrentBatchTimestamp(offsetSeqMetadata.batchTimestampMs,
+        CurrentBatchTimestamp(execCtx.offsetSeqMetadata.batchTimestampMs,
           cd.dataType, cd.timeZoneId)
     }
 
     val triggerLogicalPlan = sink match {
       case _: Sink =>
-        newAttributePlan.asInstanceOf[WriteToMicroBatchDataSourceV1].withNewBatchId(currentBatchId)
+        newAttributePlan.asInstanceOf[WriteToMicroBatchDataSourceV1].withNewBatchId(execCtx.batchId)
       case _: SupportsWrite =>
-        newAttributePlan.asInstanceOf[WriteToMicroBatchDataSource].withNewBatchId(currentBatchId)
+        newAttributePlan.asInstanceOf[WriteToMicroBatchDataSource].withNewBatchId(execCtx.batchId)
       case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
     }
 
     sparkSessionToRunBatch.sparkContext.setLocalProperty(
-      MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString)
+      MicroBatchExecution.BATCH_ID_KEY, execCtx.batchId.toString)
     sparkSessionToRunBatch.sparkContext.setLocalProperty(
       StreamExecution.IS_CONTINUOUS_PROCESSING, false.toString)
 
-    reportTimeTaken("queryPlanning") {
-      val isFirstBatch = lastExecution == null
-      lastExecution = new IncrementalExecution(
+    execCtx.reportTimeTaken("queryPlanning") {
+      execCtx.executionPlan = new IncrementalExecution(
         sparkSessionToRunBatch,
         triggerLogicalPlan,
         outputMode,
         checkpointFile("state"),
         id,
         runId,
-        currentBatchId,
-        offsetLog.offsetSeqMetadataForBatchId(currentBatchId - 1),
-        offsetSeqMetadata,
+        execCtx.batchId,
+        offsetLog.offsetSeqMetadataForBatchId(execCtx.batchId - 1),
+        execCtx.offsetSeqMetadata,
         watermarkPropagator,
-        isFirstBatch)
-      lastExecution.executedPlan // Force the lazy generation of execution plan
+        execCtx.previousContext.isEmpty)
+      execCtx.executionPlan.executedPlan // Force the lazy generation of execution plan
     }
 
-    markMicroBatchExecutionStart()
+    markMicroBatchExecutionStart(execCtx)
 
     val nextBatch =
-      new Dataset(lastExecution, ExpressionEncoder(lastExecution.analyzed.schema))
+      new Dataset(execCtx.executionPlan, ExpressionEncoder(execCtx.executionPlan.analyzed.schema))
 
-    val batchSinkProgress: Option[StreamWriterCommitProgress] = reportTimeTaken("addBatch") {
-      SQLExecution.withNewExecutionId(lastExecution) {
+    val batchSinkProgress: Option[StreamWriterCommitProgress] =
+      execCtx.reportTimeTaken("addBatch") {
+      SQLExecution.withNewExecutionId(execCtx.executionPlan) {
         sink match {
           case s: Sink =>
-            s.addBatch(currentBatchId, nextBatch)
+            s.addBatch(execCtx.batchId, nextBatch)
             // DSv2 write node has a mechanism to invalidate DSv2 relation, but there is no
             // corresponding one for DSv1. Given we have an information of catalog table for sink,
             // we can refresh the catalog table once the write has succeeded.
@@ -784,7 +848,7 @@ class MicroBatchExecution(
             // This doesn't accumulate any data - it just forces execution of the microbatch writer.
             nextBatch.collect()
         }
-        lastExecution.executedPlan match {
+        execCtx.executionPlan.executedPlan match {
           case w: WriteToDataSourceV2Exec => w.commitProgress
           case _ => None
         }
@@ -792,10 +856,10 @@ class MicroBatchExecution(
     }
 
     withProgressLocked {
-      sinkCommitProgress = batchSinkProgress
-      markMicroBatchEnd()
+      execCtx.sinkCommitProgress = batchSinkProgress
+      markMicroBatchEnd(execCtx)
     }
-    logDebug(s"Completed batch ${currentBatchId}")
+    logDebug(s"Completed batch ${execCtx.batchId}")
   }
 
 
@@ -803,46 +867,46 @@ class MicroBatchExecution(
    * Called at the start of the micro batch with given offsets. It takes care of offset
    * checkpointing to offset log and any microbatch startup tasks.
    */
-  protected def markMicroBatchStart(): Unit = {
-    if (!offsetLog.add(currentBatchId,
-      availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))) {
-      throw QueryExecutionErrors.concurrentStreamLogUpdate(currentBatchId)
+  protected def markMicroBatchStart(execCtx: MicroBatchExecutionContext): Unit = {
+    if (!offsetLog.add(execCtx.batchId,
+      execCtx.endOffsets.toOffsetSeq(sources, execCtx.offsetSeqMetadata))) {
+      throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
     }
 
-    logInfo(s"Committed offsets for batch $currentBatchId. " +
-      s"Metadata ${offsetSeqMetadata.toString}")
+    logInfo(s"Committed offsets for batch ${execCtx.batchId}. " +
+      s"Metadata ${execCtx.offsetSeqMetadata.toString}")
   }
 
   /**
    * Method called once after the planning is done and before the start of the microbatch execution.
    * It can be used to perform any pre-execution tasks.
    */
-  protected def markMicroBatchExecutionStart(): Unit = {}
+  protected def markMicroBatchExecutionStart(execCtx: MicroBatchExecutionContext): Unit = {}
 
   /**
    * Called after the microbatch has completed execution. It takes care of committing the offset
    * to commit log and other bookkeeping.
    */
-  protected def markMicroBatchEnd(): Unit = {
-    watermarkTracker.updateWatermark(lastExecution.executedPlan)
-    reportTimeTaken("commitOffsets") {
-      if (!commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
-        throw QueryExecutionErrors.concurrentStreamLogUpdate(currentBatchId)
+  protected def markMicroBatchEnd(execCtx: MicroBatchExecutionContext): Unit = {
+    watermarkTracker.updateWatermark(execCtx.executionPlan.executedPlan)
+    execCtx.reportTimeTaken("commitOffsets") {
+      if (!commitLog.add(execCtx.batchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+        throw QueryExecutionErrors.concurrentStreamLogUpdate(execCtx.batchId)
       }
     }
-    committedOffsets ++= availableOffsets
+    committedOffsets ++= execCtx.endOffsets
   }
 
-  protected def cleanUpLastExecutedMicroBatch(): Unit = {
-    if (currentBatchId != 0) {
-      val prevBatchOff = offsetLog.get(currentBatchId - 1)
+  protected def cleanUpLastExecutedMicroBatch(execCtx: MicroBatchExecutionContext): Unit = {
+    if (execCtx.batchId != 0) {
+      val prevBatchOff = offsetLog.get(execCtx.batchId - 1)
       if (prevBatchOff.isDefined) {
         commitSources(prevBatchOff.get)
         // The watermark for each batch is given as (prev. watermark, curr. watermark), hence
         // we can't purge the previous version of watermark.
-        watermarkPropagator.purge(currentBatchId - 2)
+        watermarkPropagator.purge(execCtx.batchId - 2)
       } else {
-        throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist")
+        throw new IllegalStateException(s"batch ${execCtx.batchId - 1} doesn't exist")
       }
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index ccbbf9a4d874..0d32eed9b6bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.streaming
 
 import java.time.Instant
+import java.time.ZoneId
 import java.time.format.DateTimeFormatter
 import java.util.{Optional, UUID}
 
@@ -40,89 +41,172 @@ import org.apache.spark.util.Clock
 
 /**
  * Responsible for continually reporting statistics about the amount of data processed as well
- * as latency for a streaming query.  This trait is designed to be mixed into the
- * [[StreamExecution]], who is responsible for calling `startTrigger` and `finishTrigger`
- * at the appropriate times. Additionally, the status can updated with `updateStatusMessage` to
- * allow reporting on the streams current state (i.e. "Fetching more data").
+ * as latency for a streaming query.  This class is designed to hold information about
+ * a streaming query and contains methods that can be used on a streaming query,
+ * such as getting the most recent progress of the query.
  */
-trait ProgressReporter extends Logging {
+class ProgressReporter(
+    private val sparkSession: SparkSession,
+    private val triggerClock: Clock,
+    val logicalPlan: () => LogicalPlan)
+  extends Logging {
 
-  case class ExecutionStats(
-    inputRows: Map[SparkDataStream, Long],
-    stateOperators: Seq[StateOperatorProgress],
-    eventTimeStats: Map[String, String])
-
-  // Internal state of the stream, required for computing metrics.
-  protected def id: UUID
-  protected def runId: UUID
-  protected def name: String
-  protected def triggerClock: Clock
-  protected def logicalPlan: LogicalPlan
-  protected def lastExecution: QueryExecution
-  protected def newData: Map[SparkDataStream, LogicalPlan]
-  protected def sinkCommitProgress: Option[StreamWriterCommitProgress]
-  protected def sources: Seq[SparkDataStream]
-  protected def sink: Table
+  // The timestamp we report an event that has not executed anything
+  var lastNoExecutionProgressEventTime = Long.MinValue
+
+  /** Holds the most recent query progress updates.  Accesses must lock on the queue itself. */
+  private val progressBuffer = new mutable.Queue[StreamingQueryProgress]()
+
+  val noDataProgressEventInterval: Long =
+    sparkSession.sessionState.conf.streamingNoDataProgressEventInterval
+
+  private val timestampFormat =
+    DateTimeFormatter
+      .ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601
+      .withZone(DateTimeUtils.getZoneId("UTC"))
+
+  /** Returns an array containing the most recent query progress updates. */
+  def recentProgress: Array[StreamingQueryProgress] = progressBuffer.synchronized {
+    progressBuffer.toArray
+  }
+
+  /** Returns the most recent query progress update or null if there were no progress updates. */
+  def lastProgress: StreamingQueryProgress = progressBuffer.synchronized {
+    progressBuffer.lastOption.orNull
+  }
+
+  def updateProgress(newProgress: StreamingQueryProgress): Unit = {
+    // Reset noDataEventTimestamp if we processed any data
+    lastNoExecutionProgressEventTime = triggerClock.getTimeMillis()
+
+    addNewProgress(newProgress)
+    postEvent(new QueryProgressEvent(newProgress))
+    logInfo(s"Streaming query made progress: $newProgress")
+  }
+
+  private def addNewProgress(newProgress: StreamingQueryProgress): Unit = {
+    progressBuffer.synchronized {
+      progressBuffer += newProgress
+      while (progressBuffer.length >= sparkSession.sessionState.conf.streamingProgressRetention) {
+        progressBuffer.dequeue()
+      }
+    }
+  }
+
+  def updateIdleness(
+      id: UUID,
+      runId: UUID,
+      currentTriggerStartTimestamp: Long,
+      newProgress: StreamingQueryProgress): Unit = {
+    val now = triggerClock.getTimeMillis()
+    if (now - noDataProgressEventInterval >= lastNoExecutionProgressEventTime) {
+      addNewProgress(newProgress)
+      if (lastNoExecutionProgressEventTime > Long.MinValue) {
+        postEvent(new QueryIdleEvent(id, runId, formatTimestamp(currentTriggerStartTimestamp)))
+        logInfo(s"Streaming query has been idle and waiting for new data more than " +
+          s"${noDataProgressEventInterval} ms.")
+      }
+
+      lastNoExecutionProgressEventTime = now
+    }
+  }
+
+  private def postEvent(event: StreamingQueryListener.Event): Unit = {
+    sparkSession.streams.postListenerEvent(event)
+  }
+
+  def formatTimestamp(millis: Long): String = {
+    Instant.ofEpochMilli(millis)
+      .atZone(ZoneId.of("Z")).format(timestampFormat)
+  }
+}
+
+/**
+ * This class holds variables and methods that are used to track metrics and progress
+ * during the execution lifecycle of a batch that is being processed by the streaming query.
+ */
+abstract class ProgressContext(
+    id: UUID,
+    runId: UUID,
+    name: String,
+    triggerClock: Clock,
+    sources: Seq[SparkDataStream],
+    sink: Table,
+    progressReporter: ProgressReporter)
+  extends Logging {
+
+  import ProgressContext._
+
+  // offset metadata for this batch
   protected def offsetSeqMetadata: OffsetSeqMetadata
-  protected def currentBatchId: Long
-  protected def sparkSession: SparkSession
-  protected def postEvent(event: StreamingQueryListener.Event): Unit
+
+  // the most recent input data for each source.
+  protected def newData: Map[SparkDataStream, LogicalPlan]
+
+  /** Flag that signals whether any error with input metrics have already been logged */
+  protected var metricWarningLogged: Boolean = false
+
+  @volatile
+  var sinkCommitProgress: Option[StreamWriterCommitProgress] = None
 
   // Local timestamps and counters.
-  private var currentTriggerStartTimestamp = -1L
+  protected var currentTriggerStartTimestamp = -1L
   private var currentTriggerEndTimestamp = -1L
   private var currentTriggerStartOffsets: Map[SparkDataStream, String] = _
   private var currentTriggerEndOffsets: Map[SparkDataStream, String] = _
   private var currentTriggerLatestOffsets: Map[SparkDataStream, String] = _
 
   // TODO: Restore this from the checkpoint when possible.
-  private var lastTriggerStartTimestamp = -1L
+  protected var lastTriggerStartTimestamp = -1L
 
   private val currentDurationsMs = new mutable.HashMap[String, Long]()
 
-  /** Flag that signals whether any error with input metrics have already been logged */
-  private var metricWarningLogged: Boolean = false
-
-  /** Holds the most recent query progress updates.  Accesses must lock on the queue itself. */
-  private val progressBuffer = new mutable.Queue[StreamingQueryProgress]()
-
-  private val noDataProgressEventInterval =
-    sparkSession.sessionState.conf.streamingNoDataProgressEventInterval
-
-  // The timestamp we report an event that has not executed anything
-  private var lastNoExecutionProgressEventTime = Long.MinValue
-
-  private val timestampFormat =
-    DateTimeFormatter
-      .ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") // ISO8601
-      .withZone(DateTimeUtils.getZoneId("UTC"))
+  // This field tracks the execution stats calculated while reporting metrics for the latest
+  // executed batch. We keep the value so that an idle trigger, which does not execute a batch
+  // and therefore has no snapshot of the query status of its own, can construct its progress
+  // by referring to the latest executed batch.
+  @volatile protected var execStatsOnLatestExecutedBatch: Option[ExecutionStats] = None
 
   @volatile
-  protected var currentStatus: StreamingQueryStatus = {
+  var currentStatus: StreamingQueryStatus = {
     new StreamingQueryStatus(
       message = "Initializing StreamExecution",
       isDataAvailable = false,
-      isTriggerActive = false)
+      isTriggerActive = false
+    )
   }
 
   private var latestStreamProgress: StreamProgress = _
 
-  /** Returns the current status of the query. */
-  def status: StreamingQueryStatus = currentStatus
+  /** Records the duration of running `body` for the next query progress update. */
+  def reportTimeTaken[T](triggerDetailKey: String)(body: => T): T = {
+    val startTime = triggerClock.getTimeMillis()
+    val result = body
+    val endTime = triggerClock.getTimeMillis()
+    val timeTaken = math.max(endTime - startTime, 0)
 
-  /** Returns an array containing the most recent query progress updates. */
-  def recentProgress: Array[StreamingQueryProgress] = progressBuffer.synchronized {
-    progressBuffer.toArray
+    reportTimeTaken(triggerDetailKey, timeTaken)
+    result
   }
 
-  /** Returns the most recent query progress update or null if there were no progress updates. */
-  def lastProgress: StreamingQueryProgress = progressBuffer.synchronized {
-    progressBuffer.lastOption.orNull
+  /**
+   * Reports an input duration for a particular detail key in the next query progress
+   * update. Can be used directly instead of reportTimeTaken(key)(body) when the duration
+   * is measured asynchronously.
+   */
+  def reportTimeTaken(triggerDetailKey: String, timeTakenMs: Long): Unit = {
+    val previousTime = currentDurationsMs.getOrElse(triggerDetailKey, 0L)
+    currentDurationsMs.put(triggerDetailKey, previousTime + timeTakenMs)
+    logDebug(s"$triggerDetailKey took $timeTakenMs ms")
+  }
+
+  /** Updates the message returned in `status`. */
+  def updateStatusMessage(message: String): Unit = {
+    currentStatus = currentStatus.copy(message = message)
   }
 
   /** Begins recording statistics about query progress for a given trigger. */
-  protected def startTrigger(): Unit = {
-    logDebug("Starting Trigger Calculation")
+  def startTrigger(): Unit = {
     lastTriggerStartTimestamp = currentTriggerStartTimestamp
     currentTriggerStartTimestamp = triggerClock.getTimeMillis()
     currentTriggerStartOffsets = null
@@ -135,7 +219,7 @@ trait ProgressReporter extends Logging {
    * Record the offsets range this trigger will process. Call this before updating
    * `committedOffsets` in `StreamExecution` to make sure that the correct range is recorded.
    */
-  protected def recordTriggerOffsets(
+  def recordTriggerOffsets(
       from: StreamProgress,
       to: StreamProgress,
       latest: StreamProgress): Unit = {
@@ -143,56 +227,73 @@ trait ProgressReporter extends Logging {
     currentTriggerEndOffsets = to.transform((_, v) => v.json)
     currentTriggerLatestOffsets = latest.transform((_, v) => v.json)
     latestStreamProgress = to
+    currentTriggerLatestOffsets = latest.transform((_, v) => v.json)
   }
 
-  private def addNewProgress(newProgress: StreamingQueryProgress): Unit = {
-    progressBuffer.synchronized {
-      progressBuffer += newProgress
-      while (progressBuffer.length >= sparkSession.sessionState.conf.streamingProgressRetention) {
-        progressBuffer.dequeue()
-      }
-    }
-  }
+  /** Finalizes the trigger which did not execute a batch. */
+  def finishNoExecutionTrigger(lastExecutedEpochId: Long): Unit = {
+    currentTriggerEndTimestamp = triggerClock.getTimeMillis()
+    val processingTimeMills = currentTriggerEndTimestamp - currentTriggerStartTimestamp
 
-  private def updateProgress(newProgress: StreamingQueryProgress): Unit = {
-    // Reset noDataEventTimestamp if we processed any data
-    lastNoExecutionProgressEventTime = triggerClock.getTimeMillis()
+    val execStatsOnNoExecution = execStatsOnLatestExecutedBatch.map(resetExecStatsForNoExecution)
 
-    addNewProgress(newProgress)
-    postEvent(new QueryProgressEvent(newProgress))
-    logInfo(s"Streaming query made progress: $newProgress")
-  }
+    val newProgress = constructNewProgress(processingTimeMills, lastExecutedEpochId,
+      execStatsOnNoExecution, Map.empty[String, Row])
 
-  private def updateIdleness(newProgress: StreamingQueryProgress): Unit = {
-    val now = triggerClock.getTimeMillis()
-    if (now - noDataProgressEventInterval >= lastNoExecutionProgressEventTime) {
-      addNewProgress(newProgress)
-      if (lastNoExecutionProgressEventTime > Long.MinValue) {
-        postEvent(new QueryIdleEvent(newProgress.id, newProgress.runId,
-          formatTimestamp(currentTriggerStartTimestamp)))
-        logInfo(s"Streaming query has been idle and waiting for new data more than " +
-          s"$noDataProgressEventInterval ms.")
-      }
+    progressReporter.updateIdleness(id, runId, currentTriggerStartTimestamp, newProgress)
 
-      lastNoExecutionProgressEventTime = now
-    }
+    warnIfFinishTriggerTakesTooLong(currentTriggerEndTimestamp, processingTimeMills)
+
+    currentStatus = currentStatus.copy(isTriggerActive = false)
+  }
+
+  /**
+   * Retrieve a measured duration
+   */
+  def getDuration(key: String): Option[Long] = {
+    currentDurationsMs.get(key)
   }
 
   /**
    * Finalizes the query progress and adds it to list of recent status updates.
    *
    * @param hasNewData Whether the sources of this stream had new data for this trigger.
-   * @param hasExecuted Whether any batch was executed during this trigger. Streaming queries that
-   *                    perform stateful aggregations with timeouts can still run batches even
-   *                    though the sources don't have any new data.
    */
-  protected def finishTrigger(hasNewData: Boolean, hasExecuted: Boolean): Unit = {
-    assert(currentTriggerStartOffsets != null && currentTriggerEndOffsets != null &&
-      currentTriggerLatestOffsets != null)
+  def finishTrigger(
+      hasNewData: Boolean,
+      sourceToNumInputRowsMap: Map[SparkDataStream, Long],
+      lastExecution: IncrementalExecution,
+      lastEpochId: Long): Unit = {
+    assert(
+      currentTriggerStartOffsets != null && currentTriggerEndOffsets != null &&
+        currentTriggerLatestOffsets != null
+    )
     currentTriggerEndTimestamp = triggerClock.getTimeMillis()
-
-    val executionStats = extractExecutionStats(hasNewData, hasExecuted)
     val processingTimeMills = currentTriggerEndTimestamp - currentTriggerStartTimestamp
+    assert(lastExecution != null, "executed batch should provide the information for execution.")
+    val execStats = extractExecutionStats(hasNewData, sourceToNumInputRowsMap, lastExecution)
+    logDebug(s"Execution stats: $execStats")
+
+    val observedMetrics = extractObservedMetrics(lastExecution)
+    val newProgress = constructNewProgress(processingTimeMills, lastEpochId, Some(execStats),
+      observedMetrics)
+
+    progressReporter.lastNoExecutionProgressEventTime = triggerClock.getTimeMillis()
+    progressReporter.updateProgress(newProgress)
+
+    // Update the value since this trigger executes a batch successfully.
+    this.execStatsOnLatestExecutedBatch = Some(execStats)
+
+    warnIfFinishTriggerTakesTooLong(currentTriggerEndTimestamp, processingTimeMills)
+
+    currentStatus = currentStatus.copy(isTriggerActive = false)
+  }
+
+  private def constructNewProgress(
+      processingTimeMills: Long,
+      batchId: Long,
+      execStats: Option[ExecutionStats],
+      observedMetrics: Map[String, Row]): StreamingQueryProgress = {
     val processingTimeSec = Math.max(1L, processingTimeMills).toDouble / MILLIS_PER_SECOND
 
     val inputTimeSec = if (lastTriggerStartTimestamp >= 0) {
@@ -200,10 +301,39 @@ trait ProgressReporter extends Logging {
     } else {
       Double.PositiveInfinity
     }
-    logDebug(s"Execution stats: $executionStats")
+    val sourceProgress = extractSourceProgress(execStats, inputTimeSec, processingTimeSec)
+    val sinkProgress = extractSinkProgress(execStats)
 
-    val sourceProgress = sources.distinct.map { source =>
-      val numRecords = executionStats.inputRows.getOrElse(source, 0L)
+    val eventTime = execStats.map { stats =>
+      stats.eventTimeStats.asJava
+    }.getOrElse(new java.util.HashMap)
+
+    val stateOperators = execStats.map { stats =>
+      stats.stateOperators.toArray
+    }.getOrElse(Array[StateOperatorProgress]())
+
+    new StreamingQueryProgress(
+      id = id,
+      runId = runId,
+      name = name,
+      timestamp = progressReporter.formatTimestamp(currentTriggerStartTimestamp),
+      batchId = batchId,
+      batchDuration = processingTimeMills,
+      durationMs =
+        new java.util.HashMap(currentDurationsMs.toMap.transform((_, v) => long2Long(v)).asJava),
+      eventTime = new java.util.HashMap(eventTime),
+      stateOperators = stateOperators,
+      sources = sourceProgress.toArray,
+      sink = sinkProgress,
+      observedMetrics = new java.util.HashMap(observedMetrics.asJava))
+  }
+
+  private def extractSourceProgress(
+      execStats: Option[ExecutionStats],
+      inputTimeSec: Double,
+      processingTimeSec: Double): Seq[SourceProgress] = {
+    sources.distinct.map { source =>
+      val numRecords = execStats.flatMap(_.inputRows.get(source)).getOrElse(0L)
       val sourceMetrics = source match {
         case withMetrics: ReportsSourceMetrics =>
           withMetrics.metrics(Optional.ofNullable(latestStreamProgress.get(source).orNull))
@@ -220,94 +350,47 @@ trait ProgressReporter extends Logging {
         metrics = sourceMetrics
       )
     }
+  }
 
-    val sinkOutput = if (hasExecuted) {
-      sinkCommitProgress.map(_.numOutputRows)
-    } else {
-      sinkCommitProgress.map(_ => 0L)
-    }
-
+  private def extractSinkProgress(execStats: Option[ExecutionStats]): SinkProgress = {
+    val sinkOutput = execStats.flatMap(_.outputRows)
     val sinkMetrics = sink match {
-      case withMetrics: ReportsSinkMetrics =>
-        withMetrics.metrics()
+      case withMetrics: ReportsSinkMetrics => withMetrics.metrics()
       case _ => Map[String, String]().asJava
     }
 
-    val sinkProgress = SinkProgress(
-      sink.toString, sinkOutput, sinkMetrics)
-
-    val observedMetrics = extractObservedMetrics(hasNewData, lastExecution)
-
-    val newProgress = new StreamingQueryProgress(
-      id = id,
-      runId = runId,
-      name = name,
-      timestamp = formatTimestamp(currentTriggerStartTimestamp),
-      batchId = currentBatchId,
-      batchDuration = processingTimeMills,
-      durationMs =
-        new java.util.HashMap(currentDurationsMs.toMap.transform((_, v) => long2Long(v)).asJava),
-      eventTime = new java.util.HashMap(executionStats.eventTimeStats.asJava),
-      stateOperators = executionStats.stateOperators.toArray,
-      sources = sourceProgress.toArray,
-      sink = sinkProgress,
-      observedMetrics = new java.util.HashMap(observedMetrics.asJava))
-
-    if (hasExecuted) {
-      updateProgress(newProgress)
-    } else {
-      updateIdleness(newProgress)
-    }
-
-    currentStatus = currentStatus.copy(isTriggerActive = false)
+    SinkProgress(sink.toString, sinkOutput, sinkMetrics)
   }
 
-  /** Extract statistics about stateful operators from the executed query plan. */
-  private def extractStateOperatorMetrics(hasExecuted: Boolean): Seq[StateOperatorProgress] = {
-    if (lastExecution == null) return Nil
-    // lastExecution could belong to one of the previous triggers if `!hasExecuted`.
-    // Walking the plan again should be inexpensive.
-    lastExecution.executedPlan.collect {
-      case p if p.isInstanceOf[StateStoreWriter] =>
-        val progress = p.asInstanceOf[StateStoreWriter].getProgress()
-        if (hasExecuted) {
-          progress
-        } else {
-          progress.copy(newNumRowsUpdated = 0, newNumRowsDroppedByWatermark = 0)
-        }
-    }
+  /**
+   * Overload of finishTrigger which extracts the per-source input row counts from the
+   * IncrementalExecution before delegating to the full version.
+   */
+  def finishTrigger(
+      hasNewData: Boolean,
+      lastExecution: IncrementalExecution,
+      lastEpoch: Long): Unit = {
+    val map: Map[SparkDataStream, Long] =
+      if (hasNewData) extractSourceToNumInputRows(lastExecution) else Map.empty
+    finishTrigger(hasNewData, map, lastExecution, lastEpoch)
   }
 
-  /** Extracts statistics from the most recent query execution. */
-  private def extractExecutionStats(hasNewData: Boolean, hasExecuted: Boolean): ExecutionStats = {
-    val hasEventTime = logicalPlan.collect { case e: EventTimeWatermark => e }.nonEmpty
-    val watermarkTimestamp =
-      if (hasEventTime) Map("watermark" -> formatTimestamp(offsetSeqMetadata.batchWatermarkMs))
-      else Map.empty[String, String]
-
-    // SPARK-19378: Still report metrics even though no data was processed while reporting progress.
-    val stateOperators = extractStateOperatorMetrics(hasExecuted)
-
-    if (!hasNewData) {
-      return ExecutionStats(Map.empty, stateOperators, watermarkTimestamp)
+  private def warnIfFinishTriggerTakesTooLong(
+      triggerEndTimestamp: Long,
+      processingTimeMills: Long): Unit = {
+    // Log a warning message if finishTrigger step takes more time than processing the batch and
+    // also longer than min threshold (1 minute).
+    val finishTriggerDurationMillis = triggerClock.getTimeMillis() - triggerEndTimestamp
+    val thresholdForLoggingMillis = 60 * 1000
+    if (finishTriggerDurationMillis > math.max(thresholdForLoggingMillis, processingTimeMills)) {
+      logWarning("Query progress update takes longer than batch processing time. Progress " +
+        s"update takes $finishTriggerDurationMillis milliseconds. Batch processing takes " +
+        s"$processingTimeMills milliseconds")
     }
-
-    val numInputRows = extractSourceToNumInputRows()
-
-    val eventTimeStats = lastExecution.executedPlan.collect {
-      case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 =>
-        val stats = e.eventTimeStats.value
-        Map(
-          "max" -> stats.max,
-          "min" -> stats.min,
-          "avg" -> stats.avg.toLong).transform((_, v) => formatTimestamp(v))
-    }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp
-
-    ExecutionStats(numInputRows, stateOperators, eventTimeStats)
   }
 
   /** Extract the number of input rows for each streaming source in the plan */
-  private def extractSourceToNumInputRows(): Map[SparkDataStream, Long] = {
+  private def extractSourceToNumInputRows(
+      lastExecution: IncrementalExecution): Map[SparkDataStream, Long] = {
 
     def sumRows(tuples: Seq[(SparkDataStream, Long)]): Map[SparkDataStream, Long] = {
       tuples.groupBy(_._1).transform((_, v) => v.map(_._2).sum) // sum up rows for each source
@@ -328,7 +411,7 @@ trait ProgressReporter extends Logging {
 
     val onlyDataSourceV2Sources = {
       // Check whether the streaming query's logical plan has only V2 micro-batch data sources
-      val allStreamingLeaves = logicalPlan.collect {
+      val allStreamingLeaves = progressReporter.logicalPlan().collect {
         case s: StreamingDataSourceV2ScanRelation => s.stream.isInstanceOf[MicroBatchStream]
         case _: StreamingExecutionRelation => false
       }
@@ -414,35 +497,91 @@ trait ProgressReporter extends Logging {
     }
   }
 
-  /** Extracts observed metrics from the most recent query execution. */
-  private def extractObservedMetrics(
-      hasNewData: Boolean,
-      lastExecution: QueryExecution): Map[String, Row] = {
-    if (!hasNewData || lastExecution == null) {
-      return Map.empty
+  /** Extract statistics about stateful operators from the executed query plan. */
+  private def extractStateOperatorMetrics(
+      lastExecution: IncrementalExecution): Seq[StateOperatorProgress] = {
+    assert(lastExecution != null, "lastExecution is not available")
+    lastExecution.executedPlan.collect {
+      case p if p.isInstanceOf[StateStoreWriter] =>
+        p.asInstanceOf[StateStoreWriter].getProgress()
     }
-    lastExecution.observedMetrics
   }
 
-  /** Records the duration of running `body` for the next query progress update. */
-  protected def reportTimeTaken[T](triggerDetailKey: String)(body: => T): T = {
-    val startTime = triggerClock.getTimeMillis()
-    val result = body
-    val endTime = triggerClock.getTimeMillis()
-    val timeTaken = math.max(endTime - startTime, 0)
+  /** Extracts statistics from the most recent query execution. */
+  private def extractExecutionStats(
+      hasNewData: Boolean,
+      sourceToNumInputRows: Map[SparkDataStream, Long],
+      lastExecution: IncrementalExecution): ExecutionStats = {
+    val hasEventTime = progressReporter.logicalPlan().collect {
+      case e: EventTimeWatermark => e
+    }.nonEmpty
 
-    val previousTime = currentDurationsMs.getOrElse(triggerDetailKey, 0L)
-    currentDurationsMs.put(triggerDetailKey, previousTime + timeTaken)
-    logDebug(s"$triggerDetailKey took $timeTaken ms")
-    result
+    val watermarkTimestamp =
+      if (hasEventTime) {
+        Map("watermark" -> progressReporter.formatTimestamp(offsetSeqMetadata.batchWatermarkMs))
+      } else Map.empty[String, String]
+
+    // SPARK-19378: Still report metrics even though no data was processed while reporting progress.
+    val stateOperators = extractStateOperatorMetrics(lastExecution)
+
+    val sinkOutput = sinkCommitProgress.map(_.numOutputRows)
+
+    if (!hasNewData) {
+      return ExecutionStats(Map.empty, stateOperators, watermarkTimestamp, sinkOutput)
+    }
+
+    val eventTimeStats = lastExecution.executedPlan
+      .collect {
+        case e: EventTimeWatermarkExec if e.eventTimeStats.value.count > 0 =>
+          val stats = e.eventTimeStats.value
+          Map(
+            "max" -> stats.max,
+            "min" -> stats.min,
+            "avg" -> stats.avg.toLong).transform((_, v) => progressReporter.formatTimestamp(v))
+      }.headOption.getOrElse(Map.empty) ++ watermarkTimestamp
+
+    ExecutionStats(sourceToNumInputRows, stateOperators, eventTimeStats.toMap, sinkOutput)
   }
 
-  protected def formatTimestamp(millis: Long): String = {
-    timestampFormat.format(Instant.ofEpochMilli(millis))
+  /**
+   * Resets values in the execution stats to remove those which are specific to the batch.
+   * New execution stats will only retain the values as a snapshot of the query status.
+   * (E.g. for stateful operators, numRowsTotal is a snapshot of the status, whereas
+   * numRowsUpdated is bound to the batch.)
+   * TODO: We do not seem to clear up all values in StateOperatorProgress which are bound to the
+   * batch. Fix this.
+   */
+  private def resetExecStatsForNoExecution(originExecStats: ExecutionStats): ExecutionStats = {
+    val newStatefulOperators = originExecStats.stateOperators.map { so =>
+      so.copy(newNumRowsUpdated = 0, newNumRowsDroppedByWatermark = 0)
+    }
+    val newEventTimeStats = if (originExecStats.eventTimeStats.contains("watermark")) {
+      Map("watermark" -> progressReporter.formatTimestamp(offsetSeqMetadata.batchWatermarkMs))
+    } else {
+      Map.empty[String, String]
+    }
+    val newOutputRows = originExecStats.outputRows.map(_ => 0L)
+    originExecStats.copy(
+      inputRows = Map.empty[SparkDataStream, Long],
+      stateOperators = newStatefulOperators,
+      eventTimeStats = newEventTimeStats,
+      outputRows = newOutputRows)
   }
 
-  /** Updates the message returned in `status`. */
-  protected def updateStatusMessage(message: String): Unit = {
-    currentStatus = currentStatus.copy(message = message)
+  /** Extracts observed metrics from the most recent query execution. */
+  private def extractObservedMetrics(
+      lastExecution: QueryExecution): Map[String, Row] = {
+    if (lastExecution == null) {
+      return Map.empty
+    }
+    lastExecution.observedMetrics
   }
 }
+
+object ProgressContext {
+  case class ExecutionStats(
+    inputRows: Map[SparkDataStream, Long],
+    stateOperators: Seq[StateOperatorProgress],
+    eventTimeStats: Map[String, String],
+    outputRows: Option[Long])
+}
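
To make the idle-trigger handling above concrete, here is a small, self-contained Scala sketch of the idea behind resetExecStatsForNoExecution. The names (BatchStats, OperatorStats, resetForNoExecution) are hypothetical stand-ins, not the classes in this patch: values bound to a batch are zeroed, while snapshot-style values such as operator row totals are retained.

    object ResetStatsSketch extends App {
      // Snapshot values (numRowsTotal) survive an idle trigger; per-batch values do not.
      case class OperatorStats(numRowsTotal: Long, numRowsUpdated: Long)
      case class BatchStats(
          inputRows: Map[String, Long],
          operators: Seq[OperatorStats],
          outputRows: Option[Long])

      def resetForNoExecution(prev: BatchStats): BatchStats = prev.copy(
        inputRows = Map.empty,                                       // bound to the batch, cleared
        operators = prev.operators.map(_.copy(numRowsUpdated = 0)),  // keep totals, clear updates
        outputRows = prev.outputRows.map(_ => 0L))                   // an idle trigger wrote no rows

      val lastExecuted = BatchStats(Map("kafka" -> 42L), Seq(OperatorStats(100L, 42L)), Some(42L))
      println(resetForNoExecution(lastExecuted))
      // BatchStats(Map(),List(OperatorStats(100,0)),Some(0))
    }
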
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index aac26a727689..859fce8b1154 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -40,7 +40,6 @@ import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
 import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit, SparkDataStream}
 import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsTruncate, Write}
 import org.apache.spark.sql.execution.command.StreamingExplainCommand
-import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
 import org.apache.spark.sql.streaming._
@@ -74,7 +73,7 @@ abstract class StreamExecution(
     val triggerClock: Clock,
     val outputMode: OutputMode,
     deleteCheckpointOnStop: Boolean)
-  extends StreamingQuery with ProgressReporter with Logging {
+  extends StreamingQuery with Logging {
 
   import org.apache.spark.sql.streaming.StreamingQueryListener._
 
@@ -93,8 +92,6 @@ abstract class StreamExecution(
   private val startLatch = new CountDownLatch(1)
   private val terminationLatch = new CountDownLatch(1)
 
-  def logicalPlan: LogicalPlan
-
   /**
    * Tracks how much data we have processed and committed to the sink or state store from each
    * input source.
@@ -102,33 +99,50 @@ abstract class StreamExecution(
    * Other threads should make a shallow copy if they are going to access this field more than
    * once, since the field's value may change at any time.
    */
-  @volatile
-  var committedOffsets = new StreamProgress
+  @volatile var committedOffsets = new StreamProgress
 
   /**
-   * Tracks the offsets that are available to be processed, but have not yet be committed to the
-   * sink.
-   * Only the scheduler thread should modify this field, and only in atomic steps.
-   * Other threads should make a shallow copy if they are going to access this field more than
-   * once, since the field's value may change at any time.
+   * Get the latest execution context.
    */
-  @volatile
-  var availableOffsets = new StreamProgress
+  def getLatestExecutionContext(): StreamExecutionContext
 
   /**
-   * Tracks the latest offsets for each input source.
-   * Only the scheduler thread should modify this field, and only in atomic steps.
-   * Other threads should make a shallow copy if they are going to access this field more than
-   * once, since the field's value may change at any time.
+   * Get the start offsets of the latest batch that has been planned
    */
-  @volatile
-  var latestOffsets = new StreamProgress
+  def getStartOffsetsOfLatestBatch: StreamProgress = {
+    getLatestExecutionContext().startOffsets
+  }
 
-  @volatile
-  var sinkCommitProgress: Option[StreamWriterCommitProgress] = None
+  /**
+   * Get the end offsets (formerly known as "available" offsets) of the latest batch that has
+   * been planned.
+   */
+  def availableOffsets: StreamProgress = {
+    getLatestExecutionContext().endOffsets
+  }
 
-  /** The current batchId or -1 if execution has not yet been initialized. */
-  protected var currentBatchId: Long = -1
+  def latestOffsets: StreamProgress = {
+    getLatestExecutionContext().latestOffsets
+  }
+
+  override def status: StreamingQueryStatus = {
+    getLatestExecutionContext().currentStatus
+  }
+
+  override def recentProgress: Array[StreamingQueryProgress] = progressReporter.recentProgress
+
+  override def lastProgress: StreamingQueryProgress = progressReporter.lastProgress
+
+  /**
+   * The base logical plan which will be used across batch runs. Once the value is set, it should
+   * not be modified.
+   */
+  def logicalPlan: LogicalPlan
+
+  /**
+   * The list of stream instances which will be used across batch runs. Once the value is set,
+   * it should not be modified.
+   */
+  protected def sources: Seq[SparkDataStream]
 
   /** Metadata associated with the whole query */
   protected val streamMetadata: StreamMetadata = {
@@ -141,10 +155,6 @@ abstract class StreamExecution(
     }
   }
 
-  /** Metadata associated with the offset seq of a batch in the query. */
-  protected var offsetSeqMetadata = OffsetSeqMetadata(
-    batchWatermarkMs = 0, batchTimestampMs = 0, sparkSession.conf)
-
   /**
    * A map of current watermarks, keyed by the position of the watermark operator in the
    * physical plan.
@@ -159,6 +169,9 @@ abstract class StreamExecution(
 
   override val runId: UUID = UUID.randomUUID
 
+  protected val progressReporter = new ProgressReporter(sparkSession, triggerClock,
+    () => logicalPlan)
+
   /**
    * Pretty identified string of printing in logs. Format is
    * If name is set "queryName [id = xyz, runId = abc]" else "[id = xyz, runId = abc]"
@@ -174,11 +187,7 @@ abstract class StreamExecution(
   /** Defines the internal state of execution */
   protected val state = new AtomicReference[State](INITIALIZING)
 
-  @volatile
-  var lastExecution: IncrementalExecution = _
-
-  /** Holds the most recent input data for each source. */
-  protected var newData: Map[SparkDataStream, LogicalPlan] = _
+  def lastExecution: IncrementalExecution = getLatestExecutionContext().executionPlan
 
   @volatile
   protected var streamDeathCause: StreamingQueryException = null
@@ -280,7 +289,8 @@ abstract class StreamExecution(
 
       // `postEvent` does not throw non fatal exception.
       val startTimestamp = triggerClock.getTimeMillis()
-      postEvent(new QueryStartedEvent(id, runId, name, formatTimestamp(startTimestamp)))
+      postEvent(
+        new QueryStartedEvent(id, runId, name, progressReporter.formatTimestamp(startTimestamp)))
 
       // Unblock starting thread
       startLatch.countDown()
@@ -298,18 +308,18 @@ abstract class StreamExecution(
         sparkSessionForStream.conf.set(SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key,
           "false")
 
-        updateStatusMessage("Initializing sources")
+        getLatestExecutionContext().updateStatusMessage("Initializing sources")
         // force initialization of the logical plan so that the sources can be created
         logicalPlan
 
-        offsetSeqMetadata = OffsetSeqMetadata(
+        getLatestExecutionContext().offsetSeqMetadata = OffsetSeqMetadata(
           batchWatermarkMs = 0, batchTimestampMs = 0, sparkSessionForStream.conf)
 
         if (state.compareAndSet(INITIALIZING, ACTIVE)) {
           // Unblock `awaitInitialization`
           initializationLatch.countDown()
           runActivatedStream(sparkSessionForStream)
-          updateStatusMessage("Stopped")
+          getLatestExecutionContext().updateStatusMessage("Stopped")
         } else {
           // `stop()` is already called. Let `finally` finish the cleanup.
         }
@@ -317,25 +327,31 @@ abstract class StreamExecution(
     } catch {
       case e if isInterruptedByStop(e, sparkSession.sparkContext) =>
         // interrupted by stop()
-        updateStatusMessage("Stopped")
+        getLatestExecutionContext().updateStatusMessage("Stopped")
       case e: Throwable =>
         val message = if (e.getMessage == null) "" else e.getMessage
         streamDeathCause = new StreamingQueryException(
           toDebugString(includeLogicalPlan = isInitialized),
           cause = e,
-          committedOffsets.toOffsetSeq(sources, offsetSeqMetadata).toString,
-          availableOffsets.toOffsetSeq(sources, offsetSeqMetadata).toString,
+          getLatestExecutionContext().startOffsets
+            .toOffsetSeq(sources.toSeq, getLatestExecutionContext().offsetSeqMetadata)
+            .toString,
+          getLatestExecutionContext().endOffsets
+            .toOffsetSeq(sources.toSeq, getLatestExecutionContext().offsetSeqMetadata)
+            .toString,
           errorClass = "STREAM_FAILED",
           messageParameters = Map(
             "id" -> id.toString,
             "runId" -> runId.toString,
             "message" -> message,
             "queryDebugString" -> toDebugString(includeLogicalPlan = isInitialized),
-            "startOffset" -> committedOffsets.toOffsetSeq(sources, offsetSeqMetadata).toString,
-            "endOffset" -> availableOffsets.toOffsetSeq(sources, offsetSeqMetadata).toString
+            "startOffset" -> getLatestExecutionContext().startOffsets.toOffsetSeq(
+              sources.toSeq, getLatestExecutionContext().offsetSeqMetadata).toString,
+            "endOffset" -> getLatestExecutionContext().endOffsets.toOffsetSeq(
+              sources.toSeq, getLatestExecutionContext().offsetSeqMetadata).toString
           ))
         logError(s"Query $prettyIdString terminated with error", e)
-        updateStatusMessage(s"Terminated with exception: $message")
+        getLatestExecutionContext().updateStatusMessage(s"Terminated with exception: $message")
         // Rethrow the fatal errors to allow the user using `Thread.UncaughtExceptionHandler` to
         // handle them
         if (!NonFatal(e)) {
@@ -355,7 +371,8 @@ abstract class StreamExecution(
         stopSources()
         cleanup()
         state.set(TERMINATED)
-        currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false)
+        getLatestExecutionContext().currentStatus =
+          status.copy(isTriggerActive = false, isDataAvailable = false)
 
         // Update metrics and status
         sparkSession.sparkContext.env.metricsSystem.removeSource(streamMetrics)
@@ -408,7 +425,7 @@ abstract class StreamExecution(
     }
   }
 
-  override protected def postEvent(event: StreamingQueryListener.Event): Unit = {
+  protected def postEvent(event: StreamingQueryListener.Event): Unit = {
     sparkSession.streams.postListenerEvent(event)
   }
 
@@ -592,8 +609,8 @@ abstract class StreamExecution(
     val debugString =
       s"""|=== Streaming Query ===
           |Identifier: $prettyIdString
-          |Current Committed Offsets: $committedOffsets
-          |Current Available Offsets: $availableOffsets
+          |Current Committed Offsets: ${getLatestExecutionContext().startOffsets}
+          |Current Available Offsets: ${getLatestExecutionContext().endOffsets}
           |
           |Current State: $state
           |Thread State: ${queryExecutionThread.getState}""".stripMargin
@@ -605,7 +622,8 @@ abstract class StreamExecution(
   }
 
   protected def getBatchDescriptionString: String = {
-    val batchDescription = if (currentBatchId < 0) "init" else currentBatchId.toString
+    val batchDescription = if (getLatestExecutionContext().batchId < 0) "init"
+    else getLatestExecutionContext().batchId.toString
     s"""|${Option(name).getOrElse("")}
         |id = $id
         |runId = $runId
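
The StreamExecution changes above mostly swap direct field access for delegation to the latest execution context. A minimal sketch of that pattern, using hypothetical Query/BatchCtx names rather than the real classes, could look like this:

    object DelegationSketch extends App {
      // Per-batch mutable state lives in its own context object.
      final class BatchCtx(val batchId: Long, @volatile var endOffsets: Map[String, Long] = Map.empty)

      // The long-lived query object keeps no per-batch fields; its public accessors
      // always read through whatever context is currently the latest.
      abstract class Query {
        protected def latestCtx: BatchCtx
        def availableOffsets: Map[String, Long] = latestCtx.endOffsets
        def batchDescription: String =
          if (latestCtx.batchId < 0) "init" else latestCtx.batchId.toString
      }

      @volatile private var ctx = new BatchCtx(-1)
      val q = new Query { protected def latestCtx: BatchCtx = ctx }
      println(q.batchDescription)   // init
      ctx = new BatchCtx(0, Map("source-0" -> 10L))
      println(q.availableOffsets)   // Map(source-0 -> 10)
    }
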
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecutionContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecutionContext.scala
new file mode 100644
index 000000000000..c5e14df3e20e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecutionContext.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.UUID
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.connector.read.streaming.SparkDataStream
+import org.apache.spark.util.Clock
+
+/**
+ * Holds the mutable state and metrics for a single batch of a streaming query.
+ */
+abstract class StreamExecutionContext(
+    val id: UUID,
+    runId: UUID,
+    name: String,
+    triggerClock: Clock,
+    sources: Seq[SparkDataStream],
+    sink: Table,
+    progressReporter: ProgressReporter,
+    var batchId: Long,
+    sparkSession: SparkSession)
+  extends ProgressContext(id, runId, name, triggerClock, sources, sink, progressReporter) {
+
+  /** Metadata associated with the offset seq of a batch in the query. */
+  @volatile
+  var offsetSeqMetadata: OffsetSeqMetadata = OffsetSeqMetadata(
+    batchWatermarkMs = 0, batchTimestampMs = 0, sparkSession.conf)
+
+  /** Holds the most recent input data for each source. */
+  var newData: Map[SparkDataStream, LogicalPlan] = _
+
+  /**
+   * Stores the start offset for this batch.
+   * Only the scheduler thread should modify this field, and only in atomic steps.
+   * Other threads should make a shallow copy if they are going to access this field more than
+   * once, since the field's value may change at any time.
+   */
+  @volatile
+  var startOffsets = new StreamProgress
+
+  /**
+   * Stores the end offsets for this batch.
+   * Only the scheduler thread should modify this field, and only in atomic steps.
+   * Other threads should make a shallow copy if they are going to access this field more than
+   * once, since the field's value may change at any time.
+   */
+  @volatile
+  var endOffsets = new StreamProgress
+
+  /**
+   * Tracks the latest offsets for each input source.
+   * Only the scheduler thread should modify this field, and only in atomic steps.
+   * Other threads should make a shallow copy if they are going to access this field more than
+   * once, since the field's value may change at any time.
+   */
+  @volatile
+  var latestOffsets = new StreamProgress
+
+  @volatile var executionPlan: IncrementalExecution = _
+
+  // Called at the end of the execution.
+  def onExecutionComplete(): Unit = {}
+
+  // Called at time when execution fails.
+  def onExecutionFailure(): Unit = {}
+}
+
+/**
+ * Holds all the mutable state and metrics for an epoch when using continuous execution mode.
+ */
+class ContinuousExecutionContext(
+    id: UUID,
+    runId: UUID,
+    name: String,
+    triggerClock: Clock,
+    sources: Seq[SparkDataStream],
+    sink: Table,
+    progressReporter: ProgressReporter,
+    epochId: Long,
+    sparkSession: SparkSession)
+  extends StreamExecutionContext(
+    id,
+    runId,
+    name,
+    triggerClock,
+    sources,
+    sink,
+    progressReporter,
+    epochId,
+    sparkSession)
+
+/**
+ * Holds all the mutable state and processing metrics for a single micro-batch
+ * when using micro-batch execution mode.
+ *
+ * @param _batchId the id of this batch
+ * @param previousContext the execution context of the previous micro-batch
+ */
+class MicroBatchExecutionContext(
+    id: UUID,
+    runId: UUID,
+    name: String,
+    triggerClock: Clock,
+    sources: Seq[SparkDataStream],
+    sink: Table,
+    progressReporter: ProgressReporter,
+    var _batchId: Long,
+    sparkSession: SparkSession,
+    var previousContext: Option[MicroBatchExecutionContext])
+  extends StreamExecutionContext(
+    id,
+    runId,
+    name,
+    triggerClock,
+    sources,
+    sink,
+    progressReporter,
+    _batchId,
+    sparkSession) with Logging {
+
+  /**
+   * Signifies whether the current batch (i.e. the batch for `batchId`) has been constructed
+   * (i.e. written to the offsetLog) and is ready for execution.
+   */
+  var isCurrentBatchConstructed = false
+
+  // copying some of the state from the previous batch
+  previousContext.foreach { ctx =>
+    {
+      // the start offsets are the end offsets for the previous batch
+      startOffsets = ctx.endOffsets
+
+      // needed for sources that support admission control as the start offset needs
+      // to be provided
+      endOffsets = ctx.endOffsets
+
+      latestOffsets = ctx.latestOffsets
+
+      // need to carry this over from previous batch since this gets set once and remains
+      // the same value for the rest of the run
+      metricWarningLogged = ctx.metricWarningLogged
+
+      // need to carry this over to know when the previous batch started
+      currentTriggerStartTimestamp = ctx.currentTriggerStartTimestamp
+
+      // needed to carry over to the new batch because APIs accessing this value do not expect
+      // it to be null, even if it is the old plan. Constructing the progress on an idle trigger
+      // no longer relies on executionPlan; we use carryOverExecStatsOnLatestExecutedBatch().
+      executionPlan = ctx.executionPlan
+
+      // needs to be carried over to new batch to output metrics for sink
+      // even when no data is processed.
+      sinkCommitProgress = ctx.sinkCommitProgress
+
+      // needed for test org.apache.spark.sql.streaming.EventTimeWatermarkSuite
+      // - recovery from Spark ver 2.3.1 commit log without commit metadata (SPARK-24699)
+      offsetSeqMetadata = ctx.offsetSeqMetadata
+    }
+  }
+
+  def getNextContext(): MicroBatchExecutionContext = {
+    new MicroBatchExecutionContext(
+      id,
+      runId,
+      name,
+      triggerClock,
+      sources,
+      sink,
+      progressReporter,
+      batchId + 1,
+      sparkSession,
+      Some(this))
+  }
+
+  override def startTrigger(): Unit = {
+    super.startTrigger()
+    currentStatus = currentStatus.copy(isTriggerActive = true)
+  }
+
+  override def onExecutionComplete(): Unit = {
+    // Release the ref to avoid infinite chain.
+    previousContext = None
+    super.onExecutionComplete()
+  }
+
+  override def onExecutionFailure(): Unit = {
+    // Release the ref to avoid infinite chain.
+    previousContext = None
+    super.onExecutionFailure()
+  }
+
+  override def toString: String = s"MicroBatchExecutionContext(batchId=$batchId," +
+    s" isCurrentBatchConstructed=$isCurrentBatchConstructed," +
+    s" offsetSeqMetadata=$offsetSeqMetadata," +
+    s" sinkCommitProgress=$sinkCommitProgress," +
+    s" endOffsets=$endOffsets," +
+    s" startOffsets=$startOffsets," +
+    s" latestOffsets=$latestOffsets," +
+    s" executionPlan=$executionPlan," +
+    s" currentStatus=$currentStatus)"
+
+  def carryOverExecStatsOnLatestExecutedBatch(): Unit = {
+    execStatsOnLatestExecutedBatch = previousContext.flatMap(_.execStatsOnLatestExecutedBatch)
+  }
+
+  def getStartTime(): Long = {
+    currentTriggerStartTimestamp
+  }
+}
+
+case class MicroBatchExecutionResult(isActive: Boolean, didExecute: Boolean)
+
+case class MicroBatchExecutionFailed() extends RuntimeException
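
The previousContext handling in MicroBatchExecutionContext is easiest to see in isolation. The sketch below (a hypothetical Ctx class, not the one added here) shows the same two moves: a new context is seeded from the previous one, and the back-reference is dropped once the batch completes so contexts do not form an ever-growing chain.

    object ChainSketch extends App {
      final class Ctx(val batchId: Long, var previous: Option[Ctx]) {
        // Seed state from the previous batch: its end offsets become our start offsets.
        var startOffsets: Map[String, Long] = previous.map(_.endOffsets).getOrElse(Map.empty)
        var endOffsets: Map[String, Long] = startOffsets

        def nextContext(): Ctx = new Ctx(batchId + 1, Some(this))
        def onComplete(): Unit = { previous = None }  // release the ref to avoid an infinite chain
      }

      var ctx = new Ctx(0, None)
      ctx.endOffsets = Map("source-0" -> 100L)
      ctx.onComplete()
      ctx = ctx.nextContext()
      println(ctx.batchId)       // 1
      println(ctx.startOffsets)  // Map(source-0 -> 100)
    }
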
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala
index e807471b12d1..143230759724 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala
@@ -22,10 +22,21 @@ import org.apache.spark.util.{Clock, SystemClock}
 
 trait TriggerExecutor {
 
+  private var execCtx: MicroBatchExecutionContext = _
+
   /**
    * Execute batches using `batchRunner`. If `batchRunner` returns `false`, terminate the execution.
    */
-  def execute(batchRunner: () => Boolean): Unit
+  def execute(batchRunner: (MicroBatchExecutionContext) => Boolean): Unit
+
+  def setNextBatch(execContext: MicroBatchExecutionContext): Unit = {
+    execCtx = execContext
+  }
+
+  protected def runOneBatch(
+      batchRunner: MicroBatchExecutionContext => Boolean): Boolean = {
+    batchRunner(execCtx)
+  }
 }
 
 /**
@@ -36,7 +47,9 @@ case class SingleBatchExecutor() extends TriggerExecutor {
   /**
    * Execute a single batch using `batchRunner`.
    */
-  override def execute(batchRunner: () => Boolean): Unit = batchRunner()
+  override def execute(batchRunner: (MicroBatchExecutionContext) => Boolean): Unit = {
+    runOneBatch(batchRunner)
+  }
 }
 
 /**
@@ -46,7 +59,8 @@ case class MultiBatchExecutor() extends TriggerExecutor {
   /**
    * Execute multiple batches using `batchRunner`
    */
-  override def execute(batchRunner: () => Boolean): Unit = while (batchRunner()) {}
+  override def execute(batchRunner: (MicroBatchExecutionContext) => Boolean): Unit =
+    while (runOneBatch(batchRunner)) {}
 }
 
 /**
@@ -60,11 +74,11 @@ case class ProcessingTimeExecutor(
   private val intervalMs = processingTimeTrigger.intervalMs
   require(intervalMs >= 0)
 
-  override def execute(triggerHandler: () => Boolean): Unit = {
+  override def execute(triggerHandler: (MicroBatchExecutionContext) => Boolean): Unit = {
     while (true) {
       val triggerTimeMs = clock.getTimeMillis()
       val nextTriggerTimeMs = nextBatchTime(triggerTimeMs)
-      val terminated = !triggerHandler()
+      val terminated = !runOneBatch(triggerHandler)
       if (intervalMs > 0) {
         val batchElapsedTimeMs = clock.getTimeMillis() - triggerTimeMs
         if (batchElapsedTimeMs > intervalMs) {
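
Since the TriggerExecutor contract changed from a zero-argument batch runner to one that receives the execution context registered via setNextBatch, a stripped-down sketch of the new shape may help. The types and names below are simplified stand-ins, not the real executor:

    object TriggerSketch extends App {
      final case class BatchContext(batchId: Long)

      class SimpleProcessingTimeExecutor(intervalMs: Long) {
        private var nextCtx: BatchContext = _

        def setNextBatch(ctx: BatchContext): Unit = { nextCtx = ctx }

        def execute(batchRunner: BatchContext => Boolean): Unit = {
          var continue = true
          while (continue) {
            val start = System.currentTimeMillis()
            continue = batchRunner(nextCtx)  // hand the registered context to the runner
            val elapsed = System.currentTimeMillis() - start
            if (continue && elapsed < intervalMs) Thread.sleep(intervalMs - elapsed)
          }
        }
      }

      val executor = new SimpleProcessingTimeExecutor(intervalMs = 10)
      executor.setNextBatch(BatchContext(0))
      executor.execute { ctx =>
        println(s"running batch ${ctx.batchId}")
        executor.setNextBatch(BatchContext(ctx.batchId + 1)) // the query registers the next context
        ctx.batchId < 2                                      // stop after batch 2
      }
    }
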
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index 1de05931faf5..920a7c68314b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -52,6 +52,14 @@ class ContinuousExecution(
     sparkSession, plan.name, plan.resolvedCheckpointLocation, plan.inputQuery, plan.sink,
     trigger, triggerClock, plan.outputMode, plan.deleteCheckpointOnStop) {
 
+  // needs to be a lazy val because some of the parameters will not be initialized yet
+  private lazy val latestExecutionContext: StreamExecutionContext = new ContinuousExecutionContext(
+    id, runId, name, triggerClock, sources, sink, progressReporter, -1, sparkSession)
+
+  override def getLatestExecutionContext(): StreamExecutionContext = {
+    latestExecutionContext
+  }
+
   @volatile protected var sources: Seq[ContinuousStream] = Seq()
 
   // For use only in test harnesses.
@@ -162,27 +170,28 @@ class ContinuousExecution(
    *  DONE
    */
   private def getStartOffsets(): OffsetSeq = {
+    val execCtx = latestExecutionContext.asInstanceOf[ContinuousExecutionContext]
     // Note that this will need a slight modification for exactly once. If ending offsets were
     // reported but not committed for any epochs, we must replay exactly to those offsets.
     // For at least once, we can just ignore those reports and risk duplicates.
     commitLog.getLatest() match {
       case Some((latestEpochId, _)) =>
-        updateStatusMessage("Starting new streaming query " +
+        execCtx.updateStatusMessage("Starting new streaming query " +
           s"and getting offsets from latest epoch $latestEpochId")
         val nextOffsets = offsetLog.get(latestEpochId).getOrElse {
           throw new IllegalStateException(
             s"Batch $latestEpochId was committed without end epoch offsets!")
         }
         committedOffsets = nextOffsets.toStreamProgress(sources)
-        currentBatchId = latestEpochId + 1
+        execCtx.batchId = latestEpochId + 1
 
-        logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets")
+        logDebug(s"Resuming at epoch ${execCtx.batchId} with start offsets ${execCtx.startOffsets}")
         nextOffsets
       case None =>
         // We are starting this stream for the first time. Offsets are all None.
-        updateStatusMessage("Starting new streaming query")
+        execCtx.updateStatusMessage("Starting new streaming query")
         logInfo(s"Starting new streaming query.")
-        currentBatchId = 0
+        execCtx.batchId = 0
         OffsetSeq.fill(sources.map(_ => null): _*)
     }
   }
@@ -193,8 +202,9 @@ class ContinuousExecution(
    */
   private def runContinuous(sparkSessionForQuery: SparkSession): Unit = {
     val offsets = getStartOffsets()
+    val execCtx = latestExecutionContext
 
-    if (currentBatchId > 0) {
+    if (execCtx.batchId > 0) {
       AcceptsLatestSeenOffsetHandler.setLatestSeenOffsetOnSources(Some(offsets), sources)
     }
 
@@ -212,20 +222,20 @@ class ContinuousExecution(
           " not yet supported for continuous processing")
     }
 
-    reportTimeTaken("queryPlanning") {
-      lastExecution = new IncrementalExecution(
+    execCtx.reportTimeTaken("queryPlanning") {
+      execCtx.executionPlan = new IncrementalExecution(
         sparkSessionForQuery,
         withNewSources,
         outputMode,
         checkpointFile("state"),
         id,
         runId,
-        currentBatchId,
+        execCtx.batchId,
         None,
-        offsetSeqMetadata,
+        execCtx.offsetSeqMetadata,
         WatermarkPropagator.noop(),
         false)
-      lastExecution.executedPlan // Force the lazy generation of execution plan
+      execCtx.executionPlan.executedPlan // Force the lazy generation of execution plan
     }
 
     val stream = withNewSources.collect {
@@ -236,7 +246,7 @@ class ContinuousExecution(
     sparkSessionForQuery.sparkContext.setLocalProperty(
       StreamExecution.IS_CONTINUOUS_PROCESSING, true.toString)
     sparkSessionForQuery.sparkContext.setLocalProperty(
-      ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString)
+      ContinuousExecution.START_EPOCH_KEY, execCtx.batchId.toString)
     // Add another random ID on top of the run ID, to distinguish epoch coordinators across
     // reconfigurations.
     val epochCoordinatorId = s"$runId--${UUID.randomUUID}"
@@ -250,14 +260,14 @@ class ContinuousExecution(
       stream,
       this,
       epochCoordinatorId,
-      currentBatchId,
+      execCtx.batchId,
       sparkSession,
       SparkEnv.get)
     val epochUpdateThread = new Thread(new Runnable {
       override def run: Unit = {
         try {
-          triggerExecutor.execute(() => {
-            startTrigger()
+          triggerExecutor.execute((_) => {
+            execCtx.startTrigger()
 
             if (stream.needsReconfiguration && state.compareAndSet(ACTIVE, RECONFIGURING)) {
               if (queryExecutionThread.isAlive) {
@@ -265,8 +275,8 @@ class ContinuousExecution(
               }
               false
             } else if (isActive) {
-              currentBatchId = epochEndpoint.askSync[Long](IncrementAndGetEpoch)
-              logInfo(s"New epoch $currentBatchId is starting.")
+              execCtx.batchId = epochEndpoint.askSync[Long](IncrementAndGetEpoch)
+              logInfo(s"New epoch ${execCtx.batchId} is starting.")
               true
             } else {
               false
@@ -283,8 +293,8 @@ class ContinuousExecution(
       epochUpdateThread.setDaemon(true)
       epochUpdateThread.start()
 
-      updateStatusMessage("Running")
-      reportTimeTaken("runContinuous") {
+      execCtx.updateStatusMessage("Running")
+      execCtx.reportTimeTaken("runContinuous") {
         SQLExecution.withNewExecutionId(lastExecution) {
           lastExecution.executedPlan.execute()
         }
@@ -359,14 +369,18 @@ class ContinuousExecution(
    * before this is called.
    */
   def commit(epoch: Long): Unit = {
-    updateStatusMessage(s"Committing epoch $epoch")
+    val execCtx = latestExecutionContext.asInstanceOf[ContinuousExecutionContext]
+    execCtx.updateStatusMessage(s"Committing epoch $epoch")
 
     assert(sources.length == 1, "only one continuous source supported currently")
     assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit")
 
     synchronized {
       // Record offsets before updating `committedOffsets`
-      recordTriggerOffsets(from = committedOffsets, to = availableOffsets, latest = latestOffsets)
+      execCtx.recordTriggerOffsets(
+        from = execCtx.startOffsets,
+        to = execCtx.endOffsets,
+        latest = execCtx.latestOffsets)
       if (queryExecutionThread.isAlive) {
         commitLog.add(epoch, CommitMetadata())
         val offset =
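
For the continuous path, commit(epoch) now records the trigger offsets on the execution context before updating committedOffsets. A tiny sketch of what that recording amounts to, using a hypothetical OffsetRecorder rather than Spark's API:

    object CommitSketch extends App {
      final case class TriggerOffsets(
          start: Map[String, String],
          end: Map[String, String],
          latest: Map[String, String])

      class OffsetRecorder {
        @volatile private var last: Option[TriggerOffsets] = None
        // Snapshot the start/end/latest offsets of the trigger so a later progress
        // event can report them.
        def record(
            from: Map[String, String],
            to: Map[String, String],
            latest: Map[String, String]): Unit = {
          last = Some(TriggerOffsets(from, to, latest))
        }
        def lastRecorded: Option[TriggerOffsets] = last
      }

      val recorder = new OffsetRecorder
      recorder.record(
        from = Map("rate" -> """{"offset":10}"""),
        to = Map("rate" -> """{"offset":20}"""),
        latest = Map("rate" -> """{"offset":25}"""))
      println(recorder.lastRecorded)
    }
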
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala
index 110b562be1f7..bb7f8fc98d60 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala
@@ -51,7 +51,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits {
     val executor = ProcessingTimeExecutor(ProcessingTimeTrigger("1000 milliseconds"), clock)
     val executorThread = new Thread() {
       override def run(): Unit = {
-        executor.execute(() => {
+        executor.execute((_) => {
           // Record the trigger time, increment clock if needed and
           triggerTimes.add(clock.getTimeMillis().toInt)
           clock.advance(clockIncrementInTrigger)
@@ -111,7 +111,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits {
   private def testBatchTermination(intervalMs: Long): Unit = {
     var batchCounts = 0
     val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTimeTrigger(intervalMs))
-    processingTimeExecutor.execute(() => {
+    processingTimeExecutor.execute((_) => {
       batchCounts += 1
       // If the batch termination works correctly, batchCounts should be 3 after `execute`
       batchCounts < 3
@@ -134,7 +134,7 @@ class ProcessingTimeExecutorSuite extends SparkFunSuite with TimeLimits {
             batchFallingBehindCalled = true
           }
         }
-        processingTimeExecutor.execute(() => {
+        processingTimeExecutor.execute((_) => {
           clock.waitTillTime(200)
           false
         })

