Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/11/05 04:33:35 UTC

[GitHub] [spark] jerrypeng opened a new pull request, #38517: [WIP][SPARK-39591][SS] Async Progress Tracking

jerrypeng opened a new pull request, #38517:
URL: https://github.com/apache/spark/pull/38517

   
   ### What changes were proposed in this pull request?
   
   
   ### Why are the changes needed?
   
   
   ### Does this PR introduce _any_ user-facing change?
   
   
   ### How was this patch tested?
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048793733


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+    extends CommitLog(sparkSession, path) {
+

Review Comment:
   ok
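
For readers following this thread without the full diff: a minimal sketch of the general idea behind an asynchronous log write, assuming the write is handed to an executor and the caller gets a `CompletableFuture` back. The class `SketchAsyncLog`, its in-memory store, and the method names are hypothetical illustrations, not the API of `AsyncCommitLog` in this PR.

```scala
import java.util.concurrent.{CompletableFuture, ExecutorService, Executors, TimeUnit}
import java.util.function.Supplier

import scala.collection.concurrent.TrieMap

// Hypothetical stand-in for a commit log; NOT the PR's AsyncCommitLog.
class SketchAsyncLog(executor: ExecutorService) {

  // In-memory "storage", used only for illustration.
  private val store = TrieMap.empty[Long, String]

  // Hands the write to the executor and returns a future holding the batch id.
  // The future completes exceptionally if the batch id was already written.
  def addAsync(batchId: Long, metadata: String): CompletableFuture[Long] = {
    CompletableFuture.supplyAsync(new Supplier[Long] {
      override def get(): Long = {
        if (store.putIfAbsent(batchId, metadata).isDefined) {
          throw new IllegalStateException(s"Concurrent update to the log for batch $batchId")
        }
        batchId
      }
    }, executor)
  }

  // Highest batch id written so far, if any.
  def latest: Option[Long] = store.keys.reduceOption(_ max _)
}

object SketchAsyncLogDemo {
  def main(args: Array[String]): Unit = {
    val executor = Executors.newSingleThreadExecutor()
    try {
      val log = new SketchAsyncLog(executor)
      // Submit three writes without blocking, then wait for all of them.
      val futures = (0L until 3L).map(id => log.addAsync(id, s"batch-$id"))
      futures.foreach(_.get(5, TimeUnit.SECONDS))
      println(log.latest)
    } finally {
      executor.shutdown()
    }
  }
}
```

In this sketch a failed write surfaces only when the caller inspects the future, which is why the tests in this PR check that write errors are propagated back to the stream execution thread.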





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049027325


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged with async progress tracking on
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all records 0 through 39 should have been processed, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // 10 more commit log entries should also be logged with async offset commit on
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {

Review Comment:
   ok



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org




[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049023076


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written previously
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for all batches, including 5 and 6, should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been processed, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should also be written
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be processed, so one more offset log entry should be added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // Simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(

Review Comment:
   sure



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052533233


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+  extends CommitLog(sparkSession, path) {
+
+  // The cache must be enabled because we may not persist every entry to durable storage.
+  // Entries not persisted to durable storage are kept only in memory for faster lookups.
+  assert(metadataCacheEnabled == true)
+
+  // A queue of batches written to storage.  Used to keep track of when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Writes a new batch to the commit log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompletableFuture that contains the batch id.  The future is completed when
+   *         the async write of the batch is completed.  Future may also be completed exceptionally
+   *         to indicate some write error.
+   */
+  def addAsync(batchId: Long, metadata: CommitMetadata): CompletableFuture[Long] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+    val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+      serialize(metadata, output)
+    }.thenApply((ret: Boolean) => {
+      if (ret) {
+        batchId
+      } else {
+        throw new IllegalStateException(
+          s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"
+        )
+      }
+    })
+
+    batchCache.put(batchId, metadata)
+    future
+  }
+
+  /**
+   * Adds batch to commit log only in memory and not persisted to durable storage. This method is
+   * used when we don't want to persist the commit log entry for every micro batch
+   * to durable storage
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return true if operation is successful otherwise false.
+   */
+  def addInMemory(batchId: Long, metadata: CommitMetadata): Boolean = {
+    if (batchCache.containsKey(batchId)) {
+      false
+    } else {
+      batchCache.put(batchId, metadata)
+      true
+    }
+  }
+
+  /**
+   * Purge entries in the commit log up to thresholdBatchId.
+   * @param thresholdBatchId
+   */
+  override def purge(thresholdBatchId: Long): Unit = {
+    super.purge(thresholdBatchId)
+  }
+
+  /**
+   * Adds new batch asynchronously
+   * @param batchId id of batch to write
+   * @param fn serialization function
+   * @return CompletableFuture that contains a boolean to
+   *         indicate whether the write was successful or not.
+   *         Future can also be completed exceptionally to indicate write errors.
+   */
+  private def addNewBatchByStreamAsync(batchId: Long)(
+    fn: OutputStream => Unit): CompletableFuture[Boolean] = {
+    val future = new CompletableFuture[Boolean]()
+    val batchMetadataFile = batchIdToPath(batchId)
+
+    if (batchCache.containsKey(batchId)) {
+      future.complete(false)
+      future
+    } else {
+      executorService.submit(new Runnable {
+        override def run(): Unit = {
+          try {
+            if (fileManager.exists(batchMetadataFile)) {
+              future.complete(false)
+            } else {
+              val start = System.currentTimeMillis()
+              write(
+                batchMetadataFile,
+                fn
+              )
+              logDebug(
+                s"Completion commit for batch${batchId} took" +

Review Comment:
   will fix





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052535390


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+  sparkSession: SparkSession,

Review Comment:
   will fix





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052712279


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure there are no pending async offset writes
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for the new batches 3 and 4 should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure there are no pending async offset writes
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for the new batches 5 and 6 should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for the new batches 7 and 8 should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // the commit log should match the offset log since the checkpointing interval is 0
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // the new data should be processed as batch 21 and added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error by writing a file for the next offset i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are being executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario with multiple gaps in both the offset log and the commit log
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log is missing batch 1 and the commit log only has batch 0
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),

Review Comment:
   > If Spark processes 1 -> 3 and then 3 -> 4, 1 -> 3 would be batch 3, ignoring the offset range logged in offset log for batch 3, right?
   
   That is correct.  
   
   > So when they switch to normal processing trigger, the fault tolerant semantic is still bound to the async progress tracking, "at least once", till the switched query gets processed and offset log becomes contiguous. It does not seem to be easy to reason about.
   
   Switching to async progress tracking already indicates that the user is OK with at-least-once semantics, so it shouldn't be a surprise that there could be some duplicates initially, even after switching back.
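
To make the replay behavior in this exchange concrete, here is a rough sketch of how a non-contiguous checkpoint could translate into a replayed range. It is a simplified model written for illustration, not Spark's actual recovery code: `Checkpoint`, `replayRange`, and the plain integer offsets are hypothetical names and assumptions, and each offset log entry is assumed to store only the exclusive end offset of its batch.

```scala
object AtLeastOnceReplaySketch {
  // Hypothetical model of a checkpoint:
  //   offsetLog: batch id -> exclusive end offset planned for that batch
  //   commitLog: ids of batches known to have completed
  final case class Checkpoint(offsetLog: Map[Long, Int], commitLog: Set[Long])

  /**
   * Returns the source offset range [start, end) that the first batch after a
   * restart would cover, or None if the latest logged batch was already committed.
   */
  def replayRange(cp: Checkpoint): Option[(Int, Int)] = {
    if (cp.offsetLog.isEmpty) {
      None
    } else {
      val latestBatch = cp.offsetLog.keys.max
      if (cp.commitLog.contains(latestBatch)) {
        None
      } else {
        // Start from the end of the latest committed batch that still has an
        // offset log entry; with no such batch, start from the beginning.
        val committedWithOffsets = cp.commitLog.filter(cp.offsetLog.contains)
        val start =
          if (committedWithOffsets.isEmpty) 0 else cp.offsetLog(committedWithOffsets.max)
        Some((start, cp.offsetLog(latestBatch)))
      }
    }
  }
}
```

With offsets logged for batches 0 and 3 and only batch 0 committed, the model replays offsets 1 through 3 as a single batch, even though those records may already have reached the sink before the restart. That is the source of the initial duplicates mentioned above.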
   
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052721653


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -342,17 +342,14 @@ class MicroBatchExecution(
         isCurrentBatchConstructed = true
         availableOffsets = nextOffsets.toStreamProgress(sources)
         /* Initialize committed offsets to a committed batch, which at this
-         * is the second latest batch id in the offset log. */
-        if (latestBatchId != 0) {
-          val secondLatestOffsets = offsetLog.get(latestBatchId - 1).getOrElse {
-            logError(s"The offset log for batch ${latestBatchId - 1} doesn't exist, " +
-              s"which is required to restart the query from the latest batch $latestBatchId " +
-              "from the offset log. Please ensure there are two subsequent offset logs " +
-              "available for the latest batch via manually deleting the offset file(s). " +
-              "Please also ensure the latest batch for commit log is equal or one batch " +
-              "earlier than the latest batch for offset log.")
-            throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")
-          }
+         * is the second latest batch id in the offset log.

Review Comment:
   This logic can affect the offset range of a microbatch. As the test you added shows, even with the async progress tracking flag off, the normal processing trigger can technically roll back multiple microbatches while "composing these offsets into one". This breaks the assumption behind exactly-once semantics: every microbatch must have planned its offset range before execution, and the range must not change once planned.
   
   This is why async progress tracking cannot work as-is with the Delta sink and stateful operators. We blocked this for async progress tracking, but are accidentally exposing it to the "normal" processing trigger.
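
The invariant described here (a batch's offset range is fixed once planned) can be sketched as a write-once log. This is only an illustration under stated assumptions: `WriteOnceOffsetLog`, `OffsetRange`, and `rangeToExecute` are hypothetical names, not part of Spark's `OffsetSeqLog` API.

```scala
import scala.collection.mutable

object PlannedRangeSketch {
  final case class OffsetRange(start: Long, end: Long)

  /** Hypothetical write-once log: each batch id maps to exactly one planned range. */
  final class WriteOnceOffsetLog {
    private val entries = mutable.Map.empty[Long, OffsetRange]

    /** Records the range for a batch; returns false if the batch was already planned. */
    def add(batchId: Long, range: OffsetRange): Boolean =
      if (entries.contains(batchId)) {
        false
      } else {
        entries(batchId) = range
        true
      }

    def get(batchId: Long): Option[OffsetRange] = entries.get(batchId)
  }

  /**
   * On (re-)execution, reuse the logged range if one exists; otherwise log the
   * proposed range first and then run with it.
   */
  def rangeToExecute(
      log: WriteOnceOffsetLog,
      batchId: Long,
      proposed: OffsetRange): OffsetRange =
    log.get(batchId).getOrElse {
      log.add(batchId, proposed)
      proposed
    }
}
```

Under this model, re-running batch 3 after a failure always yields the range that was originally logged for it. Composing several missing batches into a wider range for the same batch id, as described in the comment above, would violate that property, which is what sinks relying on batch ids for idempotence and stateful operators depend on.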





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1051727091


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))

Review Comment:
   `data should equal((0 to 9).toArray)` <= This succeeds locally for me.
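
   A small sketch of why the two forms are interchangeable here. The object name is made up for illustration. Plain `==` on arrays compares references, so the sketch compares contents with `sameElements`.

```scala
object RangeVersusLiteral {
  // The literal currently used in the assertion.
  val literal: Array[Int] = Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)

  // `to` is inclusive, so this yields the same ten elements.
  val fromRange: Array[Int] = (0 to 9).toArray

  // Content comparison; `literal == fromRange` would compare references instead.
  val sameContents: Boolean = literal.sameElements(fromRange)
}
```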





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052660233


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")

Review Comment:
   ok





[GitHub] [spark] HeartSaVioR commented on pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on PR #38517:
URL: https://github.com/apache/spark/pull/38517#issuecomment-1354530532

   I'm not in favor of changing normal microbatch execution. As I commented, it looks high-risk to me.
   
   Instead of a smooth transition for all cases, I'd say we should support the transition from async to sync only when the checkpoint interval is set to 0, that is, when the query has contiguous offsets. If users want to switch modes for a query that uses a nonzero checkpoint interval, they can rerun the query in async mode with the checkpoint interval set to 0, let it run for a couple of batches, and then terminate it. This will also simply work for the case of rolling back the Spark version. The transition from sync to async is always feasible, as long as the query does not hit the limitations.
   
   What do you think?
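
   As a rough sketch of the check I have in mind, assuming we only look at the batch ids retained in the offset log. `OffsetLogContiguity` and `canSwitchToSync` are hypothetical names, not existing Spark APIs.

```scala
// Hypothetical helper for illustration; not part of Spark.
object OffsetLogContiguity {
  // True when the sorted batch ids have no gaps, e.g. 3, 4, 5 but not 3, 5, 8.
  // An empty or single-entry log is treated as contiguous.
  def isContiguous(batchIds: Seq[Long]): Boolean = {
    val sorted = batchIds.sorted
    sorted.zip(sorted.drop(1)).forall { case (prev, next) => next == prev + 1 }
  }

  // Assumed guard: allow switching from async to sync only for a contiguous log.
  def canSwitchToSync(offsetLogBatchIds: Seq[Long]): Boolean =
    isContiguous(offsetLogBatchIds)
}
```

   With a checkpoint interval of 0 every batch writes an offset entry, so the ids stay contiguous; with a larger interval some batches are skipped and the guard would reject the switch.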




[GitHub] [spark] jerrypeng commented on pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on PR #38517:
URL: https://github.com/apache/spark/pull/38517#issuecomment-1322908492

   @HeartSaVioR Please review.




[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049071823


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for all batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged since the checkpointing interval is 0
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for all batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure that batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // with a checkpointing interval of 0, every new batch also gets a commit log entry
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure that batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error by writing a file for the next offset i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are being executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " not commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // process batches 0 through 5 so that gaps can then be created in the logs
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()

Review Comment:
   Why? We don't need to get all of the files.
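
   For illustration only, here is a minimal sketch (not code from the PR) of the difference between deleting one known batch file and listing the whole directory. The object name `CheckpointGapSketch` and the helper `batchIds` are hypothetical, and a temporary directory stands in for the checkpoint's `offsets` directory. Only `java.io.File` and `java.nio.file.Files` are used.

```scala
import java.io.File
import java.nio.file.Files

object CheckpointGapSketch {
  // Hypothetical helper: sorted batch ids of the visible files in a log directory.
  // It mirrors the filter/map/sorted chain the tests apply to getListOfFiles.
  def batchIds(dir: File): List[Int] =
    Option(dir.listFiles).map(_.toList).getOrElse(Nil)
      .filter(f => f.isFile && !f.isHidden)
      .map(_.getName.toInt)
      .sorted

  def main(args: Array[String]): Unit = {
    // Stand-in for checkpointLocation + "/offsets" with batches 0 through 5.
    val offsets = Files.createTempDirectory("offsets").toFile
    (0 to 5).foreach(i => new File(offsets, i.toString).createNewFile())

    // Creating a gap only needs the path of the one file; no listing is required.
    new File(offsets, "1").delete()

    // Listing is only needed afterwards, to verify what is left.
    println(batchIds(offsets)) // prints List(0, 2, 3, 4, 5)
  }
}
```

   Listing is still useful for the assertions that follow a deletion, which is where the tests call `getListOfFiles`.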





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049014203


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {

Review Comment:
   easier to read
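
   As background for the helpers in the hunk above (`waitPendingOffsetWrites`, `waitPendingPurges`), the following is a rough sketch of the retry-until-deadline idea behind wrapping an assertion in `eventually(timeout(...))`. It is an assumption-laden stand-in, not ScalaTest's implementation: `PollUntilSketch`, `pollUntil`, and the `pending` flag are hypothetical names, and the flag merely simulates an async write that finishes a little later.

```scala
import java.util.concurrent.atomic.AtomicBoolean
import scala.util.control.NonFatal

object PollUntilSketch {
  // Re-run `check` until it stops throwing; once the deadline has passed,
  // let the failure propagate to the caller instead of retrying.
  def pollUntil[T](timeoutMs: Long, intervalMs: Long = 10L)(check: => T): T = {
    val deadline = System.nanoTime() + timeoutMs * 1000000L
    var result: Option[T] = None
    while (result.isEmpty) {
      try {
        result = Some(check)
      } catch {
        // Retry only while time remains; otherwise the guard fails and the error escapes.
        case NonFatal(_) if System.nanoTime() < deadline =>
          Thread.sleep(intervalMs)
      }
    }
    result.get
  }

  def main(args: Array[String]): Unit = {
    // Simulated "write pending" flag that a background thread clears after 50 ms.
    val pending = new AtomicBoolean(true)
    val writer = new Thread(() => {
      Thread.sleep(50)
      pending.set(false)
    })
    writer.start()

    // The assertion fails at first and passes once the simulated write completes.
    pollUntil(timeoutMs = 5000) {
      assert(!pending.get, "writes still pending")
    }
    println("no writes pending")
  }
}
```

   The tests rely on the same shape: the query thread keeps running while the test thread polls until the async log writes have drained.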





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048939625


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+

Review Comment:
   ok



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org




[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048934225


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {

Review Comment:
   I am not sure I understand. This object only contains static helper methods and variables.
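
To illustrate the point about static members, here is a minimal sketch of the companion-object pattern in Scala. It is not code from this PR: `ProgressOptions`, `CheckpointIntervalKey`, and `checkpointIntervalMs` are hypothetical names chosen for the example, and only the option key string and the `"1000"` default are taken from the diff above.

```scala
// Hypothetical example, not part of the PR: a companion object used as the
// home for constants and stateless helpers (the Scala analogue of Java statics).
object ProgressOptions {
  // Constant shared by all callers, comparable to a Java `static final` field.
  val CheckpointIntervalKey: String = "asyncProgressTrackingCheckpointIntervalMs"

  // Stateless helper: reads the interval from the options map,
  // falling back to 1000 ms when the key is absent.
  private def checkpointIntervalMs(options: Map[String, String]): Long =
    options.getOrElse(CheckpointIntervalKey, "1000").toLong
}

class ProgressOptions(options: Map[String, String]) {
  // A class may call private members of its own companion object.
  val intervalMs: Long = ProgressOptions.checkpointIntervalMs(options)
}
```

The object and the class have to be defined in the same source file for the companion relationship (and the access to the private helper) to hold.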





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049093714


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was on previously
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occurred during async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict.
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occurred during async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occurred during async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occurred during async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 5; gaps are created below by deleting log files
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 2; gaps are created below by deleting log files
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+  }
+
+  test("switching async progress tracking with interval commits on and off") {
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    var clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    inputData.commits.length should be(0)
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(3).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(4).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    /**
+     * Turn async progress tracking on
+     */
+    clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 6), // will get persisted since first batch on restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      AddData(inputData, 7),
+      AdvanceManualClock(100),
+      CheckNewAnswer(7),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(5).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 8),
+      AdvanceManualClock(100),
+      CheckNewAnswer(8),
+      AddData(inputData, 9),
+      AdvanceManualClock(800), //  persist offset
+      CheckNewAnswer(9),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    // simulate batch 9 not having a commit log entry
+    new File(checkpointLocation + "/commits/9").delete()
+    new File(checkpointLocation + "/commits/.9.crc").delete()
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 10),
+      CheckNewAnswer(7, 8, 9, 10),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(9).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 11),
+      CheckNewAnswer(11),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(10).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0, 3, 4, 5, 6, 9, 10 and 11 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9, 10, 11))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9, 10, 11))
+  }
+
+  test("Fail on stateful pipelines") {
+    val rateStream = spark.readStream
+      .format("rate")
+      .option("numPartitions", 1)
+      .option("rowsPerSecond", 10)
+      .load()
+      .toDF()
+
+    val windowedStream = rateStream
+      .withWatermark("timestamp", "0 seconds")
+      .groupBy(window(column("timestamp"), "10 seconds"), column("value"))
+      .count()
+
+    val e = intercept[StreamingQueryException] {
+      val query = windowedStream.writeStream
+        .format("noop")
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+
+      query.processAllAvailable()
+    }
+    e.getMessage should include(

Review Comment:
   ok



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048790837


##########
core/src/main/scala/org/apache/spark/util/ThreadUtils.scala:
##########
@@ -167,6 +167,27 @@ private[spark] object ThreadUtils {
     Executors.newFixedThreadPool(1, threadFactory).asInstanceOf[ThreadPoolExecutor]
   }
 
+  /**
+   * Wrapper over newSingleThreadExecutor that allows the specification
+   * of a RejectedExecutionHandler
+   */
+  def newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    threadName: String,

Review Comment:
   The other methods in this class only have 2 spaces?
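
   For context, a single-thread daemon executor that accepts a
   RejectedExecutionHandler could look roughly like the sketch below. This is
   only an illustration under assumptions: the object name
   SingleThreadExecutorSketch and the method name newDaemonSingleThreadExecutor
   are hypothetical, and the body is not taken from the PR. It uses only the
   standard java.util.concurrent API.

```scala
import java.util.concurrent.{LinkedBlockingQueue, RejectedExecutionHandler, ThreadFactory, ThreadPoolExecutor, TimeUnit}

// Hypothetical sketch, not the PR's implementation.
object SingleThreadExecutorSketch {

  // Builds a pool with exactly one daemon worker thread. Tasks that the
  // pool refuses (for example after shutdown) are passed to `handler`.
  def newDaemonSingleThreadExecutor(
      threadName: String,
      handler: RejectedExecutionHandler): ThreadPoolExecutor = {
    val factory = new ThreadFactory {
      override def newThread(r: Runnable): Thread = {
        val t = new Thread(r, threadName)
        t.setDaemon(true) // do not keep the JVM alive for this thread
        t
      }
    }
    new ThreadPoolExecutor(
      1, // corePoolSize
      1, // maximumPoolSize
      0L,
      TimeUnit.MILLISECONDS,
      new LinkedBlockingQueue[Runnable](),
      factory,
      handler)
  }
}
```

   A caller might pass new ThreadPoolExecutor.AbortPolicy() as the handler to
   have rejected tasks raise a RejectedExecutionException; that choice is
   likewise an assumption made for the example.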





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048791828


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2014,7 +2014,6 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
-

Review Comment:
   OK





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049075731


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches (11 through 20) should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been processed exactly once and in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches (11 through 20) should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for batches 11 through 20 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occurred during async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by pre-creating the file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occurred during async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by pre-creating the commit file for the next batch, i.e. 1
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occurred during async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the offsets directory read-only
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occurred during async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the commits directory read-only
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 5 so that both logs are fully populated;
+    // the gaps are created below by deleting individual log files
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty

Review Comment:
   will correct
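
A possible follow-up, offered only as a sketch: the suite repeats the same "list the non-hidden files, parse the names as batch ids, sort" chain for the `offsets` and `commits` directories many times. A small helper could express that once. The name `listBatchIds` and the enclosing object are hypothetical and not part of this PR. Like the existing code, it assumes that `.crc` checksum files are reported as hidden, which holds on Unix-like file systems where the names start with a dot.

```scala
import java.io.File
import java.nio.file.Files

object CheckpointLogListing {
  // Hypothetical helper (not part of the PR): returns the sorted batch ids
  // found as non-hidden regular files directly under `dir`.
  // Returns an empty list when `dir` does not exist or is not a directory.
  def listBatchIds(dir: String): List[Int] = {
    val d = new File(dir)
    if (d.exists && d.isDirectory) {
      d.listFiles.toList
        .filter(f => f.isFile && !f.isHidden)
        .map(_.getName.toInt)
        .sorted
    } else {
      List.empty[Int]
    }
  }

  def main(args: Array[String]): Unit = {
    // Build a throwaway directory that mimics <checkpoint>/offsets.
    val tmp = Files.createTempDirectory("offsets").toFile
    Seq("0", "3", ".0.crc", ".3.crc").foreach { name =>
      new File(tmp, name).createNewFile()
    }
    println(listBatchIds(tmp.getPath))
    println(listBatchIds(tmp.getPath + "/missing"))
  }
}
```

With such a helper, an assertion in the suite could read `listBatchIds(checkpointLocation + "/offsets") should equal(List(0, 3))`, which keeps the expected batch ids next to the directory being checked.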



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org




[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049070725


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for all batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for the new batches 5 and 6 should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for all batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // data should contain every value from 0 to 39, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // the commit log should contain an entry for every batch as well, i.e. up to batch 20
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // only one more offset log entry, i.e. batch 21, should be added for the new data
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+

Review Comment:
   ok
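
For readers following the assertions in this suite: each check lists the files under `<checkpointLocation>/offsets` or `<checkpointLocation>/commits`, drops hidden files, and compares the sorted batch ids. A minimal sketch of that pattern as a standalone helper, assuming a hypothetical `CheckpointFiles.listBatchIds` (this name is illustrative and not part of the PR):

```scala
import java.io.File

// Hypothetical helper, not part of the PR: it mirrors the
// getListOfFiles(...).filter(...).map(...).sorted chain used in the tests.
object CheckpointFiles {
  // Returns the sorted batch ids of the visible files directly under `dir`.
  // Assumes every visible file name is an integer batch id, as in the
  // offsets/ and commits/ directories of a streaming checkpoint.
  def listBatchIds(dir: String): Seq[Int] = {
    val d = new File(dir)
    if (d.exists && d.isDirectory) {
      d.listFiles.toSeq
        .filter(f => f.isFile && !f.isHidden)
        .map(_.getName.toInt)
        .sorted
    } else {
      // A missing directory is treated as an empty log.
      Seq.empty
    }
  }
}
```

With such a helper, a check like `CheckpointFiles.listBatchIds(checkpointLocation + "/commits") should equal(Seq(0, 3))` would stand in for the repeated chain.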



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048789799


##########
connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala:
##########
@@ -195,6 +200,102 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
     true
   }
 
+  /**
+   * Test async progress tracking capability with Kafka source and sink
+   */
+  test("async progress tracking") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 5)
+
+    val dataSent = new ListBuffer[String]()
+    testUtils.sendMessages(inputTopic, (0 until 15).map { case x =>
+      val m = s"foo-$x"
+      dataSent += m
+      m
+    }.toArray, Some(0))
+
+    val outputTopic = newTopic()
+    testUtils.createTopic(outputTopic, partitions = 5)
+
+    withTempDir { dir =>
+      val reader = spark
+        .readStream
+        .format("kafka")
+        .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+        .option("kafka.metadata.max.age.ms", "1")
+        .option("maxOffsetsPerTrigger", 5)
+        .option("subscribe", inputTopic)
+        .option("startingOffsets", "earliest")
+        .load()
+
+      def startQuery(): StreamingQuery = {
+        reader.writeStream
+          .format("kafka")
+          .option("checkpointLocation", dir.getCanonicalPath)
+          .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+          .option("kafka.max.block.ms", "5000")
+          .option("topic", outputTopic)
+          .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+          .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+          .queryName("kafkaStream")
+          .start()
+      }
+
+      def readResults(): ListBuffer[String] = {
+        val data = new ListBuffer[String]()
+        val readQuery = spark.readStream
+          .format("kafka")
+          .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+          .option("startingOffsets", "earliest")
+          .option("subscribe", outputTopic)
+          .load().writeStream
+          .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+            ds.collect.foreach((row: Row) => {
+              val v: String = new String(row.getAs("value").asInstanceOf[Array[Byte]])
+              data += v
+            }: Unit)
+          }).start()
+
+        try {
+          readQuery.processAllAvailable()
+        } finally {
+          readQuery.stop()
+        }
+        data
+      }
+
+      val query = startQuery()
+      try {
+        query.processAllAvailable()
+      } finally {
+        query.stop()
+      }
+
+      val data = readResults()
+      data should equal (dataSent)
+
+      /**

Review Comment:
   ok





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049102400


##########
core/src/main/scala/org/apache/spark/util/ThreadUtils.scala:
##########
@@ -167,6 +167,27 @@ private[spark] object ThreadUtils {
     Executors.newFixedThreadPool(1, threadFactory).asInstanceOf[ThreadPoolExecutor]
   }
 
+  /**
+   * Wrapper over newSingleThreadExecutor that allows the specification
+   * of a RejectedExecutionHandler
+   */
+  def newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    threadName: String,

Review Comment:
   ok
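
As context for the Scaladoc quoted above, a single-threaded daemon executor with a caller-supplied `RejectedExecutionHandler` can be built directly on `java.util.concurrent.ThreadPoolExecutor`. The following is only a sketch under assumptions: the object name, method name, and the `taskQueueCapacity` parameter are hypothetical, and the PR's actual wrapper (truncated in the diff above) may differ.

```scala
import java.util.concurrent.{LinkedBlockingQueue, RejectedExecutionHandler, ThreadFactory, ThreadPoolExecutor, TimeUnit}

// Hypothetical stand-in for the wrapper described in the Scaladoc;
// names and signature are illustrative, not the PR's code.
object SingleThreadExecutors {
  def newDaemonSingleThread(
      threadName: String,
      taskQueueCapacity: Int,
      handler: RejectedExecutionHandler): ThreadPoolExecutor = {
    // Daemon threads do not keep the JVM alive on shutdown.
    val factory = new ThreadFactory {
      override def newThread(r: Runnable): Thread = {
        val t = new Thread(r, threadName)
        t.setDaemon(true)
        t
      }
    }
    // Exactly one worker thread; tasks beyond the bounded queue, or tasks
    // submitted after shutdown, are handed to `handler`.
    new ThreadPoolExecutor(
      1,
      1,
      0L,
      TimeUnit.MILLISECONDS,
      new LinkedBlockingQueue[Runnable](taskQueueCapacity),
      factory,
      handler)
  }
}
```

The handler decides what happens to a rejected task, for example logging and dropping it instead of throwing `RejectedExecutionException` as the default `AbortPolicy` does.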





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049023076


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batch 2, 3, 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for all batches 0 through 6 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for all batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // no data should be reprocessed, so every value appears exactly once
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(

Review Comment:
   Let's keep the existing one so that we can monitor whether the level of nesting remains the same
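
For context, here is a minimal sketch of why an assertion might reach through two levels of `getCause`. It assumes the async write failure surfaces through a `CompletableFuture` and is then wrapped once more by a query-level exception. `OuterQueryException` is a hypothetical stand-in for `StreamingQueryException`, not Spark's class, and the actual nesting in the PR may differ.

```scala
import java.util.concurrent.{CompletableFuture, CompletionException}

object CauseNestingSketch {
  // Hypothetical stand-in for the query-level exception (assumption, not Spark's class).
  final class OuterQueryException(message: String, cause: Throwable)
      extends Exception(message, cause)

  // Simulates an async log write that fails with the same message the test helper throws.
  def failingWrite(): CompletableFuture[String] = {
    val future = new CompletableFuture[String]()
    future.completeExceptionally(new Exception("test fail"))
    future
  }

  // Joins the future and wraps whatever it throws, adding one more level of nesting.
  def runAndWrap(): OuterQueryException = {
    try {
      failingWrite().join()
      throw new AssertionError("the write was expected to fail")
    } catch {
      case e: CompletionException => new OuterQueryException("query terminated", e)
    }
  }

  def main(args: Array[String]): Unit = {
    val outer = runAndWrap()
    // Level 1 is the CompletionException added by join(); level 2 is the original failure.
    println(outer.getCause.getClass.getSimpleName)
    println(outer.getCause.getCause.getMessage)
  }
}
```

If a refactoring adds or removes a wrapper, the depth changes and an assertion written against a fixed depth fails, which is the kind of signal the comment above asks to preserve.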



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048961837


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala:
##########
@@ -300,10 +322,8 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
     }
   }
 
-
   /**
-   * List the available batches on file system. As a workaround for S3 inconsistent list, it also
-   * tries to take `batchCache` into consideration to infer a better answer.
+   * List the available batches on file system

Review Comment:
   ok





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048930639


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncOffsetSeqLog.scala:
##########
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.{Clock, SystemClock}
+
+/**
+ * Used to write entries to the offset log asynchronously
+ */
+class AsyncOffsetSeqLog(
+    sparkSession: SparkSession,
+    path: String,
+    executorService: ThreadPoolExecutor,
+    offsetCommitIntervalMs: Long,
+    clock: Clock = new SystemClock())
+    extends OffsetSeqLog(sparkSession, path) {
+
+  // the cache needs to be enabled because we may not be persisting every entry to durable storage
+  // entries not persisted to durable storage will just be stored in memory for faster lookups
+  assert(metadataCacheEnabled == true)
+
+  // A map of the current pending offset writes. Key -> batch Id, Value -> CompletableFuture
+  // Used to determine if a commit log entry for this batch also needs to be persisted to storage
+  private val pendingOffsetWrites = new ConcurrentHashMap[Long, CompletableFuture[Long]]()
+
+  // Keeps track of the last time a commit was issued. Used for issuing commits to storage at
+  // the configured intervals
+  private val lastCommitIssuedTimestampMs: AtomicLong = new AtomicLong(-1)
+
+  // A queue of batches written to storage.  Used to keep track of when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Get an async offset write by batch id. Used to check whether a corresponding commit log
+   * entry needs to be written to durable storage as well
+   * @param batchId
+   * @return an option indicating whether an async offset write was issued for the given batch id
+   */
+  def getAsyncOffsetWrite(batchId: Long): Option[CompletableFuture[Long]] = {
+    Option(pendingOffsetWrites.get(batchId))
+  }
+
+  /**
+   * Remove the async offset write when we don't need to keep track of it anymore
+   * @param batchId
+   */
+  def removeAsyncOffsetWrite(batchId: Long): Unit = {
+    pendingOffsetWrites.remove(batchId)
+  }
+
+  /**
+   * Writes a new batch to the offset log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompletableFuture that contains the batch id.  The future is completed when
+   *         the async write of the batch is completed.  Future may also be completed exceptionally
+   *         to indicate some write error.
+   */
+  def addAsync(batchId: Long, metadata: OffsetSeq): CompletableFuture[(Long, Boolean)] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+
+    def issueAsyncWrite(batchId: Long): CompletableFuture[Long] = {
+      lastCommitIssuedTimestampMs.set(clock.getTimeMillis())
+      val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+        serialize(metadata, output)
+      }.thenApply((ret: Boolean) => {
+        if (ret) {
+          batchId
+        } else {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"

Review Comment:
   This is the message in the existing implementation. We don't really differentiate which log it is. However, you should be able to tell which log it is from the stack trace.
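
To illustrate the point, the sketch below mirrors the `thenApply` pattern from the diff in a self-contained form. `writeIfAbsent` is a hypothetical stand-in for `addNewBatchByStreamAsync`, and `failureMessage` is a helper invented here. The message carries only the batch id, so the stack trace is what tells the offset log apart from the commit log.

```scala
import java.util.concurrent.{CompletableFuture, CompletionException}

object ConcurrentUpdateSketch {
  // Hypothetical stand-in for addNewBatchByStreamAsync: completes with false
  // when an entry for batchId already exists (assumption for illustration).
  def writeIfAbsent(existing: Set[Long], batchId: Long): CompletableFuture[Boolean] =
    CompletableFuture.completedFuture(!existing.contains(batchId))

  // Maps a successful write to the batch id; a false result fails the future.
  def addAsync(existing: Set[Long], batchId: Long): CompletableFuture[Long] =
    writeIfAbsent(existing, batchId).thenApply[Long]((written: Boolean) => {
      if (written) {
        batchId
      } else {
        throw new IllegalStateException(
          s"Concurrent update to the log. Multiple streaming jobs detected for $batchId")
      }
    })

  // Returns the failure message, if any. Exceptions thrown inside thenApply
  // reach the caller of join() wrapped in a CompletionException.
  def failureMessage(future: CompletableFuture[Long]): Option[String] =
    try {
      future.join()
      None
    } catch {
      case e: CompletionException => Option(e.getCause).map(_.getMessage)
    }

  def main(args: Array[String]): Unit = {
    val onDisk = Set(0L, 1L)
    println(failureMessage(addAsync(onDisk, 2L))) // new batch id, no conflict
    println(failureMessage(addAsync(onDisk, 1L))) // batch 1 already exists, conflict
  }
}
```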





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048926395


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+    extends CommitLog(sparkSession, path) {
+
+  // A queue of batches written to storage.  Used to keep track of when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Writes a new batch to the commit log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompletableFuture that contains the batch id.  The future is completed when
+   *         the async write of the batch is completed.  Future may also be completed exceptionally
+   *         to indicate some write error.
+   */
+  def addAsync(batchId: Long, metadata: CommitMetadata): CompletableFuture[Long] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+    val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+      serialize(metadata, output)
+    }.thenApply((ret: Boolean) => {
+      if (ret) {
+        batchId
+      } else {
+        throw new IllegalStateException(
+          s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"
+        )
+      }
+    })
+

Review Comment:
   It is missing, but we have it in our internal implementation. Not sure how this was missed. Will add it back.





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048930422


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncOffsetSeqLog.scala:
##########
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.{Clock, SystemClock}
+
+/**
+ * Used to write entries to the offset log asynchronously
+ */
+class AsyncOffsetSeqLog(
+    sparkSession: SparkSession,
+    path: String,
+    executorService: ThreadPoolExecutor,
+    offsetCommitIntervalMs: Long,
+    clock: Clock = new SystemClock())
+    extends OffsetSeqLog(sparkSession, path) {

Review Comment:
   ok





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052662233


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries through batch 20 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // one more offset log entry, for batch 21, should be added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error by writing a file for the next offset i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/offsets")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)

Review Comment:
   It's an extra check to make sure that, even after we instruct the query to stop, no additional or unexpected modifications are made to the offset log or the commit log.
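
   A minimal sketch of what that check amounts to, assuming plain local files stand in for the checkpoint logs. The `CheckpointLogListing` object and its `listBatchIds` helper are hypothetical names used only for illustration (they are not part of the PR), and the query stop is only simulated by a comment:

   ```scala
   import java.io.File
   import java.nio.file.Files

   object CheckpointLogListing {
     // Hypothetical helper: returns the sorted batch ids found in one checkpoint log
     // directory. Hidden files (on Unix-like systems, names starting with a dot, such
     // as ".0.crc") are skipped. Assumes every visible file name is a batch id.
     def listBatchIds(logDir: String): List[Int] = {
       val dir = new File(logDir)
       if (dir.exists && dir.isDirectory) {
         dir.listFiles.toList
           .filter(f => f.isFile && !f.isHidden)
           .map(_.getName.toInt)
           .sorted
       } else {
         List.empty[Int]
       }
     }

     def main(args: Array[String]): Unit = {
       // Simulated checkpoint layout with three batches in each log.
       val root = Files.createTempDirectory("checkpoint-sketch").toFile
       val offsets = new File(root, "offsets")
       val commits = new File(root, "commits")
       offsets.mkdirs()
       commits.mkdirs()
       Seq(0, 1, 2).foreach { id =>
         new File(offsets, id.toString).createNewFile()
         new File(commits, id.toString).createNewFile()
       }

       // Snapshot taken while the (simulated) query is still running.
       val before = (listBatchIds(offsets.getPath), listBatchIds(commits.getPath))
       // ... query.stop() would happen here in the real test ...
       val after = (listBatchIds(offsets.getPath), listBatchIds(commits.getPath))

       assert(before == after, "logs changed after stop")
       assert(after._1 == after._2, "offset and commit logs disagree")
     }
   }
   ```

   Comparing a snapshot taken before the stop with one taken after it is one way to express "nothing changed"; the test in the diff instead compares the two logs with each other once more after `query.stop()`.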



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049119941


##########
connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala:
##########
@@ -195,6 +200,102 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
     true
   }
 
+  /**
+   * Test async progress tracking capability with Kafka source and sink
+   */
+  test("async progress tracking") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 5)
+
+    val dataSent = new ListBuffer[String]()
+    testUtils.sendMessages(inputTopic, (0 until 15).map { case x =>
+      val m = s"foo-$x"
+      dataSent += m
+      m
+    }.toArray, Some(0))
+
+    val outputTopic = newTopic()
+    testUtils.createTopic(outputTopic, partitions = 5)
+
+    withTempDir { dir =>
+      val reader = spark
+        .readStream
+        .format("kafka")
+        .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+        .option("kafka.metadata.max.age.ms", "1")
+        .option("maxOffsetsPerTrigger", 5)
+        .option("subscribe", inputTopic)
+        .option("startingOffsets", "earliest")
+        .load()
+
+      def startQuery(): StreamingQuery = {
+        reader.writeStream
+          .format("kafka")
+          .option("checkpointLocation", dir.getCanonicalPath)
+          .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+          .option("kafka.max.block.ms", "5000")
+          .option("topic", outputTopic)
+          .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+          .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+          .queryName("kafkaStream")
+          .start()
+      }
+
+      def readResults(): ListBuffer[String] = {

Review Comment:
   The output would be the same, but both the code and the actual execution would be much simpler with a batch query. See the code below for what it looks like if we just go with a batch query:
   
   ```
   val data = spark.read.format("kafka")...load().selectExpr("CAST(value AS STRING)").as[String].collect()
   ```





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048946439


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"

Review Comment:
   Will improve the error message.





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049099509


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 0 through 6 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all values from 0 until 40 should have been received, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // the commit log should likewise contain entries up to batch 20
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // the new data is processed as batch 21, so one more offset log entry should be added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log contains batches 0, 2 and 5
+    // and the commit log contains batches 0 and 2
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log contains batches 0 and 2
+    // and the commit log only contains batch 0
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+  }
+
+  test("switching async progress tracking with interval commits on and off") {
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    var clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    inputData.commits.length should be(0)
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(3).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(4).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    /**
+     * Turn async progress tracking on
+     */
+    clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 6), // will get persisted since first batch on restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      AddData(inputData, 7),
+      AdvanceManualClock(100),
+      CheckNewAnswer(7),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(5).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 8),
+      AdvanceManualClock(100),
+      CheckNewAnswer(8),
+      AddData(inputData, 9),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(9),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    // simulate batch 9 not having a commit log entry
+    new File(checkpointLocation + "/commits/9").delete()
+    new File(checkpointLocation + "/commits/.9.crc").delete()
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 10),
+      CheckNewAnswer(7, 8, 9, 10),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(9).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 11),
+      CheckNewAnswer(11),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(10).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0, 3, 4, 5, 6, 9, 10 and 11 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9, 10, 11))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9, 10, 11))
+  }
+
+  test("Fail on stateful pipelines") {
+    val rateStream = spark.readStream
+      .format("rate")
+      .option("numPartitions", 1)
+      .option("rowsPerSecond", 10)
+      .load()
+      .toDF()
+
+    val windowedStream = rateStream
+      .withWatermark("timestamp", "0 seconds")
+      .groupBy(window(column("timestamp"), "10 seconds"), column("value"))
+      .count()
+
+    val e = intercept[StreamingQueryException] {
+      val query = windowedStream.writeStream
+        .format("noop")
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+
+      query.processAllAvailable()
+    }
+    e.getMessage should include(
+      "Stateful streaming queries does not support async progress tracking at this moment."
+    )
+  }
+
+  test("Fail on pipelines using unsupported sinks") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("parquet")
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(
+          "checkpointLocation",
+          Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+        )
+        .start("/tmp")
+    }
+
+    e.getMessage should equal("Sink FileSink[/tmp] does not support async progress tracking")
+  }
+
+  test("with log purging") {
+
+    withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "false") {
+      withTempDir { checkpointLocation =>
+        val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+        val ds = inputData.toDS()
+
+        val clock = new StreamManualClock
+        testStream(
+          ds,
+          extraOptions = Map(
+            ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+            ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+          )
+        )(
+          // need to set processing time to something so manual clock will work
+          StartStream(
+            Trigger.ProcessingTime("1 millisecond"),
+            checkpointLocation = checkpointLocation.getCanonicalPath,
+            triggerClock = clock
+          ),
+          AddData(inputData, 0),
+          AdvanceManualClock(100),
+          CheckNewAnswer(0),
+          AddData(inputData, 1),
+          AdvanceManualClock(100),
+          CheckNewAnswer(1),
+          AddData(inputData, 2),
+          AdvanceManualClock(100),
+          CheckNewAnswer(2),
+          AddData(inputData, 3),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(3),
+          AddData(inputData, 4),
+          AdvanceManualClock(100),
+          CheckNewAnswer(4),
+          AddData(inputData, 5),
+          AdvanceManualClock(100),
+          CheckNewAnswer(5),
+          AddData(inputData, 6),
+          AdvanceManualClock(100),
+          CheckNewAnswer(6),
+          AddData(inputData, 7),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(7),
+          Execute { q =>
+            // wait for all async log writes to finish
+            waitPendingOffsetWrites(q)
+          },
+          // add a new row to make sure log purge has kicked in.
+          // There can be a race condition in which the commit log entry for the previous batch
+          // may or may not be written to disk yet before the log purge is called.
+          // Adding another batch here will make sure purge is called on the correct number of
+          // offset and commit log entries
+          AddData(inputData, 8),
+          AdvanceManualClock(100),
+          CheckNewAnswer(8),
+          Execute { q =>
+            waitPendingOffsetWrites(q)
+            getListOfFiles(checkpointLocation + "/offsets")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(3, 7))
+            getListOfFiles(checkpointLocation + "/commits")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(3, 7))
+            q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].writtenToDurableStorage.size() should be(2)
+            q.commitLog.asInstanceOf[AsyncCommitLog].writtenToDurableStorage.size() should be(2)
+          },
+          StopStream
+        )
+
+        /**
+         * restart
+         */
+        testStream(
+          ds,
+          extraOptions = Map(
+            ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+            ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+          )
+        )(
+          // need to set processing time to something so manual clock will work
+          StartStream(
+            Trigger.ProcessingTime("1 millisecond"),
+            checkpointLocation = checkpointLocation.getCanonicalPath,
+            triggerClock = clock
+          ),
+          AddData(inputData, 9),
+          AdvanceManualClock(100),
+          CheckNewAnswer(8, 9),
+          AddData(inputData, 10),
+          AdvanceManualClock(100),
+          CheckNewAnswer(10),
+          AddData(inputData, 11),
+          AdvanceManualClock(100),
+          CheckNewAnswer(11),
+          AddData(inputData, 12),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(12),
+          Execute { q =>
+            // wait for all async log writes to finish
+            waitPendingOffsetWrites(q)
+          },
+          // add a new row to make sure log purge has kicked in.
+          // There can be a race condition in which the commit log entry for the previous batch
+          // may or may not be written to disk yet before the log purge is called.
+          // Adding another batch here will make sure purge is called on the correct number of
+          // offset and commit log entries
+          AddData(inputData, 13),
+          AdvanceManualClock(100),
+          CheckNewAnswer(13),
+          Execute { q =>
+            waitPendingOffsetWrites(q)
+            getListOfFiles(checkpointLocation + "/offsets")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(8, 12))
+            getListOfFiles(checkpointLocation + "/commits")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(8, 12))
+
+            q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].writtenToDurableStorage.size() should be(2)
+            q.commitLog.asInstanceOf[AsyncCommitLog].writtenToDurableStorage.size() should be(2)
+          },
+          StopStream
+        )
+      }
+    }
+  }
+
+  test("with async log purging") {
+    withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") {
+      withTempDir { checkpointLocation =>
+        val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+        val ds = inputData.toDS()
+
+        val clock = new StreamManualClock
+        testStream(
+          ds,
+          extraOptions = Map(
+            ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+            ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+          )
+        )(
+          // need to set processing time to something so manual clock will work
+          StartStream(
+            Trigger.ProcessingTime("1 millisecond"),
+            checkpointLocation = checkpointLocation.getCanonicalPath,
+            triggerClock = clock
+          ),
+          AddData(inputData, 0),
+          AdvanceManualClock(100),
+          CheckNewAnswer(0),
+          AddData(inputData, 1),
+          AdvanceManualClock(100),
+          CheckNewAnswer(1),
+          AddData(inputData, 2),
+          AdvanceManualClock(100),
+          CheckNewAnswer(2),
+          AddData(inputData, 3),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          Execute { q =>
+            // wait for async log writes to complete
+            waitPendingOffsetWrites(q)
+            eventually(timeout(Span(5, Seconds))) {
+              getListOfFiles(checkpointLocation + "/offsets")
+                .filter(file => !file.isHidden)
+                .map(file => file.getName.toInt)
+                .sorted should equal(Array(0, 3))
+
+              getListOfFiles(checkpointLocation + "/commits")
+                .filter(file => !file.isHidden)
+                .map(file => file.getName.toInt)
+                .sorted should equal(Array(0, 3))
+            }
+          },
+          CheckNewAnswer(3),
+          AddData(inputData, 4),
+          AdvanceManualClock(100),
+          CheckNewAnswer(4),
+          AddData(inputData, 5),
+          AdvanceManualClock(100),
+          CheckNewAnswer(5),
+          AddData(inputData, 6),
+          AdvanceManualClock(100),
+          CheckNewAnswer(6),
+          AddData(inputData, 7),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(7),
+          Execute { q =>
+            // wait for async log writes to complete
+            waitPendingOffsetWrites(q)
+            // can contain batches 0, 3, 7 or 3, 7
+            eventually(timeout(Span(5, Seconds))) {
+              getListOfFiles(checkpointLocation + "/offsets")
+                .filter(file => !file.isHidden)
+                .map(file => file.getName.toInt)
+                .sorted should contain allElementsOf (Array(3, 7))
+
+              // can contain batches 0, 3, 7 or 3, 7
+              getListOfFiles(checkpointLocation + "/commits")
+                .filter(file => !file.isHidden)
+                .map(file => file.getName.toInt)
+                .sorted should contain allElementsOf (Array(3, 7))
+            }
+          },
+          AddData(inputData, 8),
+          AdvanceManualClock(100),
+          // the commit log entry for batch 7 may not be written yet at this point

Review Comment:
   I will correct the comment. Let's keep the logic below to make sure the purge actually happened.
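For readers following the expected file listings in the test above, here is a minimal sketch of why batches 0, 3 and 7 are the ones written to durable storage and why only 3 and 7 remain after a purge with two batches retained. It is a simplified model inferred from the test expectations, not the Spark implementation: `AsyncCheckpointModel`, `persistedBatches` and `retainedAfterPurge` are hypothetical names, and the rule that the first batch is always persisted while later batches are persisted only once the checkpoint interval has elapsed is an assumption.

```scala
// Hypothetical, simplified model of the behaviour the tests assert on.
// This is NOT the Spark implementation.
object AsyncCheckpointModel {

  /**
   * Returns the batch ids written to durable storage, given the clock time
   * (in ms) at which each batch runs. Batch ids are the positions in the input.
   */
  def persistedBatches(batchTimesMs: Seq[Long], intervalMs: Long): Vector[Int] = {
    var lastPersistMs: Option[Long] = None
    val persisted = Vector.newBuilder[Int]
    batchTimesMs.zipWithIndex.foreach { case (now, batchId) =>
      // Assumed rule: the first batch is always persisted; later batches only
      // once at least intervalMs has passed since the previous durable write.
      if (lastPersistMs.forall(last => now - last >= intervalMs)) {
        persisted += batchId
        lastPersistMs = Some(now)
      }
    }
    persisted.result()
  }

  /**
   * Entries left on disk after a purge that keeps only the newest
   * `minRetain` persisted entries and drops everything older.
   */
  def retainedAfterPurge(persisted: Seq[Int], minRetain: Int): Seq[Int] =
    persisted.takeRight(minRetain)
}
```

With batch times of 100, 200, 300, 1100, 1200, 1300, 1400 and 2200 ms (mirroring the `AdvanceManualClock` steps) and a 1000 ms interval, `persistedBatches` yields batches 0, 3 and 7, and `retainedAfterPurge` with `minRetain = 2` leaves 3 and 7.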



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048787783


##########
connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala:
##########
@@ -195,6 +200,102 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
     true
   }
 
+  /**
+   * Test async progress tracking capability with Kafka source and sink
+   */
+  test("async progress tracking") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 5)
+
+    val dataSent = new ListBuffer[String]()
+    testUtils.sendMessages(inputTopic, (0 until 15).map { case x =>
+      val m = s"foo-$x"
+      dataSent += m
+      m
+    }.toArray, Some(0))
+
+    val outputTopic = newTopic()
+    testUtils.createTopic(outputTopic, partitions = 5)
+
+    withTempDir { dir =>
+      val reader = spark
+        .readStream
+        .format("kafka")
+        .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+        .option("kafka.metadata.max.age.ms", "1")
+        .option("maxOffsetsPerTrigger", 5)
+        .option("subscribe", inputTopic)
+        .option("startingOffsets", "earliest")
+        .load()
+
+      def startQuery(): StreamingQuery = {
+        reader.writeStream
+          .format("kafka")
+          .option("checkpointLocation", dir.getCanonicalPath)

Review Comment:
   sure





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048958238


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"
+          )
+        }
+      }
+      offsetLog.removeAsyncOffsetWrite(currentBatchId)
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  // need to look at the number of files on disk
+  override def purge(threshold: Long): Unit = {
+    while (offsetLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      offsetLog.writtenToDurableStorage.poll()
+    }
+    offsetLog.purge(offsetLog.writtenToDurableStorage.peek())
+
+    while (commitLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      commitLog.writtenToDurableStorage.poll()
+    }
+    commitLog.purge(commitLog.writtenToDurableStorage.peek())
+  }
+
+  override def cleanup(): Unit = {
+    super.cleanup()
+
+    ThreadUtils.shutdown(asyncWritesExecutorService)
+    logInfo(s"Async progress tracking executor pool for query ${prettyIdString} has been shutdown")
+  }
+
+  // used for testing
+  def areWritesPendingOrInProgress(): Boolean = {
+    asyncWritesExecutorService.getQueue.size() > 0 || asyncWritesExecutorService.getActiveCount > 0
+  }
+
+  private def validateAndGetTrigger(): TriggerExecutor = {
+    // validate that the pipeline is using a supported sink
+    if (!extraOptions
+      .get(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK
+      )
+      .getOrElse("false")
+      .toBoolean) {
+      try {
+        plan.sink.name() match {
+          case "noop-table" =>

Review Comment:
   ok
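The rejected-execution handler in the hunk above applies back-pressure: when the bounded task queue is full, the submitting thread blocks on `getQueue.put` instead of dropping the write. Below is a minimal standalone sketch of that pattern. It uses a plain `java.util.concurrent.ThreadPoolExecutor` rather than Spark's `ThreadUtils` helper, and `BlockingSubmitSketch` and `runInOrder` are hypothetical names chosen for illustration.

```scala
import java.util.concurrent.{ArrayBlockingQueue, ConcurrentLinkedQueue}
import java.util.concurrent.{RejectedExecutionException, RejectedExecutionHandler}
import java.util.concurrent.{ThreadPoolExecutor, TimeUnit}

// Hypothetical sketch of the back-pressure pattern; not the Spark code.
object BlockingSubmitSketch {

  // When the bounded queue is full, block the submitting thread until a slot
  // frees up instead of failing the task.
  private val blockWhenFull = new RejectedExecutionHandler {
    override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
      if (!executor.isShutdown) {
        try {
          executor.getQueue.put(r)
        } catch {
          case e: InterruptedException =>
            Thread.currentThread.interrupt()
            throw new RejectedExecutionException("Producer interrupted", e)
        }
      }
    }
  }

  /** Runs `n` tasks on one worker thread with a queue of `capacity`; returns completion order. */
  def runInOrder(n: Int, capacity: Int): List[Int] = {
    val done = new ConcurrentLinkedQueue[Int]()
    // A single worker thread means tasks execute serially, in submission order.
    val pool = new ThreadPoolExecutor(
      1, 1, 0L, TimeUnit.MILLISECONDS, new ArrayBlockingQueue[Runnable](capacity), blockWhenFull)
    try {
      (0 until n).foreach { i =>
        pool.execute(new Runnable {
          override def run(): Unit = {
            Thread.sleep(5) // simulate a slow log write so the queue fills up
            done.add(i)
            ()
          }
        })
      }
    } finally {
      pool.shutdown()
      pool.awaitTermination(10, TimeUnit.SECONDS)
    }
    // Copy the results out in completion order.
    val out = List.newBuilder[Int]
    val it = done.iterator()
    while (it.hasNext) out += it.next()
    out.result()
  }
}
```

Calling `BlockingSubmitSketch.runInOrder(10, 2)` submits more tasks than the queue can hold, so the producer pauses on the full queue, and every task still completes in submission order.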





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049071928


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was enabled in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should be logged even with async progress tracking on
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been processed in order, without duplicates
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commits for batches 11 through 20 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 5, then delete log entries below to create
+    // multiple gaps in the offset and commit logs
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+

Review Comment:
   ok
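
   For readers following the gap scenario in the quoted test, here is a
   minimal standalone sketch (not part of the PR) of how the batch IDs present
   in a checkpoint log directory can be listed and the missing ones found.
   `CheckpointBatchIds`, `batchIds`, and `gaps` are hypothetical names. It
   assumes, as the test does, that each batch is a file named by its numeric
   ID and that checksum files are dot-prefixed (e.g. `.1.crc`).

```scala
import java.io.File
import scala.util.Try

// Hypothetical helper, for illustration only; not an API from the PR.
object CheckpointBatchIds {

  // Sorted batch IDs of the visible, numerically named files in `dir`.
  // Dot-prefixed files such as ".1.crc" are skipped.
  def batchIds(dir: File): Seq[Long] = {
    val files = Option(dir.listFiles).getOrElse(Array.empty[File])
    files.toSeq
      .filter(f => f.isFile && !f.getName.startsWith("."))
      .flatMap(f => Try(f.getName.toLong).toOption)
      .sorted
  }

  // IDs missing between the smallest and largest ID that is present.
  def gaps(ids: Seq[Long]): Seq[Long] =
    if (ids.isEmpty) Seq.empty
    else {
      val present = ids.toSet
      (ids.min to ids.max).filterNot(present)
    }
}
```

   With files `0`, `2`, and `5` left in `offsets` (as after the deletions in
   the test above), `batchIds` yields 0, 2, 5 and `gaps` yields 1, 3, 4.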





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049095786


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written by the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all values should have been received exactly once and in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the new batches should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by creating a file for the next offset, i.e. 1;
+        // this should cause a conflict when the query tries to write that offset
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 5; log entries are deleted below to create gaps in both logs
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 2; log entries are deleted below to create gaps in both logs
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // restart from the same checkpoint with the default trigger
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+  }
+
+  test("switching async progress tracking with interval commits on and off") {
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    var clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    inputData.commits.length should be(0)
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // restart from the same checkpoint with the default trigger
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(3).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(4).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    /**
+     * Turn async progress tracking on
+     */
+    clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 6), // will get persisted since first batch on restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      AddData(inputData, 7),
+      AdvanceManualClock(100),
+      CheckNewAnswer(7),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(5).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 8),
+      AdvanceManualClock(100),
+      CheckNewAnswer(8),
+      AddData(inputData, 9),
+      AdvanceManualClock(800), // persist offset
+      CheckNewAnswer(9),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    // simulate batch 9 missing its commit log entry
+    new File(checkpointLocation + "/commits/9").delete()
+    new File(checkpointLocation + "/commits/.9.crc").delete()
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // restart from the same checkpoint with the default trigger
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 10),
+      CheckNewAnswer(7, 8, 9, 10),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(9).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 11),
+      CheckNewAnswer(11),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(10).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0, 3, 4, 5, 6, 9, 10 and 11 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9, 10, 11))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9, 10, 11))
+  }
+
+  test("Fail on stateful pipelines") {
+    val rateStream = spark.readStream
+      .format("rate")
+      .option("numPartitions", 1)
+      .option("rowsPerSecond", 10)
+      .load()
+      .toDF()
+
+    val windowedStream = rateStream
+      .withWatermark("timestamp", "0 seconds")
+      .groupBy(window(column("timestamp"), "10 seconds"), column("value"))
+      .count()
+
+    val e = intercept[StreamingQueryException] {
+      val query = windowedStream.writeStream
+        .format("noop")
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+
+      query.processAllAvailable()
+    }
+    e.getMessage should include(
+      "Stateful streaming queries does not support async progress tracking at this moment."
+    )
+  }
+
+  test("Fail on pipelines using unsupported sinks") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("parquet")
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(
+          "checkpointLocation",
+          Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+        )
+        .start("/tmp")
+    }
+
+    e.getMessage should equal("Sink FileSink[/tmp] does not support async progress tracking")
+  }
+
+  test("with log purging") {
+
+    withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "false") {
+      withTempDir { checkpointLocation =>
+        val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+        val ds = inputData.toDS()
+
+        val clock = new StreamManualClock
+        testStream(
+          ds,
+          extraOptions = Map(
+            ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+            ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+          )
+        )(
+          // need to set processing time to something so manual clock will work
+          StartStream(
+            Trigger.ProcessingTime("1 millisecond"),
+            checkpointLocation = checkpointLocation.getCanonicalPath,
+            triggerClock = clock
+          ),
+          AddData(inputData, 0),
+          AdvanceManualClock(100),
+          CheckNewAnswer(0),
+          AddData(inputData, 1),
+          AdvanceManualClock(100),
+          CheckNewAnswer(1),
+          AddData(inputData, 2),
+          AdvanceManualClock(100),
+          CheckNewAnswer(2),
+          AddData(inputData, 3),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(3),
+          AddData(inputData, 4),
+          AdvanceManualClock(100),
+          CheckNewAnswer(4),
+          AddData(inputData, 5),
+          AdvanceManualClock(100),
+          CheckNewAnswer(5),
+          AddData(inputData, 6),
+          AdvanceManualClock(100),
+          CheckNewAnswer(6),
+          AddData(inputData, 7),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(7),
+          Execute { q =>
+            // wait for all async log writes to finish
+            waitPendingOffsetWrites(q)
+          },
+          // add a new row to make sure log purge has kicked in.
+          // There can be a race condition in which the commit log entry for the previous batch
+          // may or may not be written to disk yet before the log purge is called.
+          // Adding another batch here will make sure purge is called on the correct number of
+          // offset and commit log entries
+          AddData(inputData, 8),
+          AdvanceManualClock(100),
+          CheckNewAnswer(8),
+          Execute { q =>
+            waitPendingOffsetWrites(q)
+            getListOfFiles(checkpointLocation + "/offsets")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(3, 7))
+            getListOfFiles(checkpointLocation + "/commits")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(3, 7))
+            q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].writtenToDurableStorage.size() should be(2)
+            q.commitLog.asInstanceOf[AsyncCommitLog].writtenToDurableStorage.size() should be(2)
+          },
+          StopStream
+        )
+
+        /**
+         * restart
+         */
+        testStream(
+          ds,
+          extraOptions = Map(
+            ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+            ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+          )
+        )(
+          // need to set processing time to something so manual clock will work
+          StartStream(
+            Trigger.ProcessingTime("1 millisecond"),
+            checkpointLocation = checkpointLocation.getCanonicalPath,
+            triggerClock = clock
+          ),
+          AddData(inputData, 9),
+          AdvanceManualClock(100),
+          CheckNewAnswer(8, 9),
+          AddData(inputData, 10),
+          AdvanceManualClock(100),
+          CheckNewAnswer(10),
+          AddData(inputData, 11),
+          AdvanceManualClock(100),
+          CheckNewAnswer(11),
+          AddData(inputData, 12),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(12),
+          Execute { q =>
+            // wait for all async log writes to finish
+            waitPendingOffsetWrites(q)
+          },
+          // add a new row to make sure log purge has kicked in.
+          // There can be a race condition in which the commit log entry for the previous batch
+          // may or may not be written to disk yet before the log purge is called.
+          // Adding another batch here will make sure purge is called on the correct number of
+          // offset and commit log entries
+          AddData(inputData, 13),
+          AdvanceManualClock(100),
+          CheckNewAnswer(13),
+          Execute { q =>
+            waitPendingOffsetWrites(q)
+            getListOfFiles(checkpointLocation + "/offsets")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(8, 12))
+            getListOfFiles(checkpointLocation + "/commits")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(8, 12))
+
+            q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].writtenToDurableStorage.size() should be(2)
+            q.commitLog.asInstanceOf[AsyncCommitLog].writtenToDurableStorage.size() should be(2)
+          },
+          StopStream
+        )
+      }
+    }
+  }
+
+  test("with async log purging") {
+    withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") {
+      withTempDir { checkpointLocation =>
+        val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+        val ds = inputData.toDS()
+
+        val clock = new StreamManualClock
+        testStream(
+          ds,
+          extraOptions = Map(
+            ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+            ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+          )
+        )(
+          // need to set processing time to something so manual clock will work
+          StartStream(
+            Trigger.ProcessingTime("1 millisecond"),
+            checkpointLocation = checkpointLocation.getCanonicalPath,
+            triggerClock = clock
+          ),
+          AddData(inputData, 0),
+          AdvanceManualClock(100),
+          CheckNewAnswer(0),
+          AddData(inputData, 1),
+          AdvanceManualClock(100),
+          CheckNewAnswer(1),
+          AddData(inputData, 2),
+          AdvanceManualClock(100),
+          CheckNewAnswer(2),
+          AddData(inputData, 3),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          Execute { q =>
+            // wait for async log writes to complete
+            waitPendingOffsetWrites(q)
+            eventually(timeout(Span(5, Seconds))) {

Review Comment:
   Just to make sure they are visible on disk
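
   For illustration only, here is a minimal sketch of that "poll until the files are visible on disk" pattern in plain Scala, without ScalaTest or Spark. The names `VisibleOnDisk`, `pollUntil` and `batchIds` are hypothetical and not part of this PR; the suite itself uses ScalaTest's `eventually` together with `getListOfFiles`.

   ```scala
   import java.io.File
   import java.nio.file.Files

   object VisibleOnDisk {
     // Hypothetical helper: re-evaluate `check` until it is true or `timeoutMs` elapses.
     def pollUntil(timeoutMs: Long, intervalMs: Long = 50)(check: => Boolean): Boolean = {
       val deadline = System.currentTimeMillis() + timeoutMs
       var ok = check
       while (!ok && System.currentTimeMillis() < deadline) {
         Thread.sleep(intervalMs)
         ok = check
       }
       ok
     }

     // List batch ids the way the suite does: non-hidden regular files named by integer.
     def batchIds(dir: File): List[Int] =
       Option(dir.listFiles).map(_.toList).getOrElse(Nil)
         .filter(f => f.isFile && !f.isHidden)
         .map(_.getName.toInt)
         .sorted

     def main(args: Array[String]): Unit = {
       val dir = Files.createTempDirectory("commits").toFile
       // Simulate an async writer whose files land on disk a little later.
       val writer = new Thread(() => {
         Thread.sleep(100)
         Seq(8, 12).foreach(id => new File(dir, id.toString).createNewFile())
       })
       writer.start()
       // Poll instead of asserting immediately, since the write is asynchronous.
       val visible = pollUntil(5000) { batchIds(dir) == List(8, 12) }
       writer.join()
       println(s"visible=$visible ids=${batchIds(dir)}")
     }
   }
   ```

   The point of the retry loop is that an immediate assertion could race with the asynchronous writer, whereas polling with a timeout only fails if the files never show up.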





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049079507


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // data should contain every value from 0 through 39, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // the commit log should contain matching entries for batches 0 through 20
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should also be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1;
+        // this should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 5; gaps are introduced below by deleting log files
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 2; gaps are introduced below by deleting log files
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+  }
+
+  test("switching async progress tracking with interval commits on and off") {
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    var clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    inputData.commits.length should be(0)
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work

Review Comment:
   will remove





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1043003950


##########
core/src/main/scala/org/apache/spark/util/ThreadUtils.scala:
##########
@@ -167,6 +167,27 @@ private[spark] object ThreadUtils {
     Executors.newFixedThreadPool(1, threadFactory).asInstanceOf[ThreadPoolExecutor]
   }
 
+  /**
+   * Wrapper over newSingleThreadExecutor that allows the specification
+   * of a RejectedExecutionHandler
+   */
+  def newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    threadName: String,

Review Comment:
   nit: 4 spaces for definition of params
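
For reference, a wrapper of the kind described in the doc comment above could look roughly like the sketch below, written with parameters indented 4 spaces as the nit asks. It is only an illustration: `newDaemonSingleThreadExecutorSketch` and `taskQueueCapacity` are hypothetical names, and the body is an assumption built on plain `java.util.concurrent`, not the code in this PR.

```scala
import java.util.concurrent.{ArrayBlockingQueue, RejectedExecutionHandler, ThreadFactory, ThreadPoolExecutor, TimeUnit}

object SingleThreadExecutorSketch {
  // Hypothetical helper (assumption): a single daemon thread, a bounded task
  // queue, and a caller-supplied handler for tasks that do not fit the queue.
  def newDaemonSingleThreadExecutorSketch(
      threadName: String,
      taskQueueCapacity: Int,
      rejectedExecutionHandler: RejectedExecutionHandler): ThreadPoolExecutor = {
    val threadFactory = new ThreadFactory {
      override def newThread(r: Runnable): Thread = {
        val t = new Thread(r, threadName)
        t.setDaemon(true) // do not keep the JVM alive for pending writes
        t
      }
    }
    new ThreadPoolExecutor(
      1, // corePoolSize: exactly one worker so tasks run in submission order
      1, // maximumPoolSize
      0L,
      TimeUnit.MILLISECONDS,
      new ArrayBlockingQueue[Runnable](taskQueueCapacity),
      threadFactory,
      rejectedExecutionHandler)
  }
}
```

With one worker thread and a bounded queue, the handler decides what happens once the queue is full, for example blocking the producer or running the task on the caller's thread.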



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2014,7 +2014,6 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
-

Review Comment:
   nit: unnecessary change



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -342,17 +342,14 @@ class MicroBatchExecution(
         isCurrentBatchConstructed = true
         availableOffsets = nextOffsets.toStreamProgress(sources)
         /* Initialize committed offsets to a committed batch, which at this
-         * is the second latest batch id in the offset log. */
-        if (latestBatchId != 0) {
-          val secondLatestOffsets = offsetLog.get(latestBatchId - 1).getOrElse {
-            logError(s"The offset log for batch ${latestBatchId - 1} doesn't exist, " +
-              s"which is required to restart the query from the latest batch $latestBatchId " +
-              "from the offset log. Please ensure there are two subsequent offset logs " +
-              "available for the latest batch via manually deleting the offset file(s). " +
-              "Please also ensure the latest batch for commit log is equal or one batch " +
-              "earlier than the latest batch for offset log.")
-            throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")
-          }
+         * is the second latest batch id in the offset log.
+         * The offset log may not be contiguous */
+        val prevBatchId = offsetLog.getPrevBatchFromStorage(latestBatchId)
+        if (latestBatchId != 0 && prevBatchId.isDefined) {
+            val secondLatestOffsets = offsetLog.get(prevBatchId.get).getOrElse({

Review Comment:
   nit: indentation looks to be off
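
To illustrate what a non-contiguous offset log implies for this lookup: the previous batch is the largest batch id on disk below the latest one, not necessarily `latestBatchId - 1`. A minimal sketch, assuming the ids have already been listed from storage; `prevBatchId` is a hypothetical helper and not the PR's `getPrevBatchFromStorage`.

```scala
object PrevBatchSketch {
  // Hypothetical helper (assumption): returns the largest batch id strictly
  // below latestBatchId, or None when no earlier batch exists on disk.
  def prevBatchId(batchIdsOnDisk: Seq[Long], latestBatchId: Long): Option[Long] = {
    val earlier = batchIdsOnDisk.filter(_ < latestBatchId)
    if (earlier.isEmpty) None else Some(earlier.max)
  }
}
```

For an offset log holding only batches 0 and 3, the batch before 3 is 0; for a log holding only batch 2, there is no previous batch.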



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala:
##########
@@ -148,6 +148,24 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
     }
   }
 
+  /**
+   * Get the id of the previous batch from storage

Review Comment:
   nit: We don't require filling out every part of the form with meaningless info. Please remove any parts that you don't find helpful or that you are reiterating just to fill them out.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -342,17 +342,14 @@ class MicroBatchExecution(
         isCurrentBatchConstructed = true
         availableOffsets = nextOffsets.toStreamProgress(sources)
         /* Initialize committed offsets to a committed batch, which at this
-         * is the second latest batch id in the offset log. */
-        if (latestBatchId != 0) {
-          val secondLatestOffsets = offsetLog.get(latestBatchId - 1).getOrElse {
-            logError(s"The offset log for batch ${latestBatchId - 1} doesn't exist, " +
-              s"which is required to restart the query from the latest batch $latestBatchId " +
-              "from the offset log. Please ensure there are two subsequent offset logs " +
-              "available for the latest batch via manually deleting the offset file(s). " +
-              "Please also ensure the latest batch for commit log is equal or one batch " +
-              "earlier than the latest batch for offset log.")
-            throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")
-          }
+         * is the second latest batch id in the offset log.

Review Comment:
   I'm actually in favor of limiting the change to the async progress tracking path, as we already do by adding protected methods for extension.
   
   Do we have a goal of supporting a smooth transition between normal microbatch execution and async progress tracking for a single query? If so, we should clearly explain the semantics and behavior during the transition between the two (in both directions). Otherwise, I'd rather leave normal microbatch execution working the same as before.
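
The extension approach mentioned here can be sketched with a toy template-method example: the base class keeps its synchronous behavior, and only the subclass overrides the protected hooks. All class and method names below are illustrative stand-ins, not the actual `MicroBatchExecution` API.

```scala
import scala.collection.mutable.ListBuffer

object HookSketch {
  // Toy stand-in for the normal execution path; it records what each hook did.
  class BaseExecution {
    val events = new ListBuffer[String]()

    protected def markBatchStart(batchId: Long): Unit =
      events += s"sync-offset-write:$batchId"

    protected def markBatchEnd(batchId: Long): Unit =
      events += s"sync-commit-write:$batchId"

    // The batch lifecycle is fixed here; subclasses only change the hooks.
    final def runBatch(batchId: Long): Unit = {
      markBatchStart(batchId)
      events += s"run:$batchId"
      markBatchEnd(batchId)
    }
  }

  // Toy stand-in for the async variant: same lifecycle, different hooks.
  class AsyncExecution extends BaseExecution {
    override protected def markBatchStart(batchId: Long): Unit =
      events += s"async-offset-write:$batchId"

    override protected def markBatchEnd(batchId: Long): Unit =
      events += s"async-commit-write:$batchId"
  }
}
```

In this arrangement the base class behaves exactly as before unless a subclass overrides a hook, which is the property the comment asks to preserve for normal microbatch execution.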



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -727,18 +719,56 @@ class MicroBatchExecution(
 
     withProgressLocked {
       sinkCommitProgress = batchSinkProgress
-      watermarkTracker.updateWatermark(lastExecution.executedPlan)
-      reportTimeTaken("commitOffsets") {
-        assert(commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark)),
-          "Concurrent update to the commit log. Multiple streaming jobs detected for " +
-            s"$currentBatchId")
-      }
-      committedOffsets ++= availableOffsets
+      markMicroBatchEnd()
     }
     logDebug(s"Completed batch ${currentBatchId}")
   }
 
-  /** Execute a function while locking the stream from making an progress */
+
+  /**
+   * Called at the start of the micro batch with given offsets. It takes care of offset
+   * checkpointing to offset log and any microbatch startup tasks.
+   */
+  protected def markMicroBatchStart(): Unit = {
+    assert(offsetLog.add(currentBatchId,
+      availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)),
+      s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId")
+    logInfo(s"Committed offsets for batch $currentBatchId. " +
+      s"Metadata ${offsetSeqMetadata.toString}")
+  }
+
+  /**
+   * Method called once after the planning is done and before the start of the microbatch execution.
+   * It can be used to perform any pre-execution tasks.
+   */
+  protected def markMicroBatchExecutionStart(): Unit = {}
+
+  /**
+   * Called after the microbatch has completed execution. It takes care of committing the offset
+   * to commit log and other bookkeeping.
+   */
+  protected def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      assert(commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark)),
+        "Concurrent update to the commit log. Multiple streaming jobs detected for " +
+          s"$currentBatchId")
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  protected def cleanUpLastExecutedMicroBatch(): Unit = {
+    if (currentBatchId != 0) {
+      val prevBatchOff = offsetLog.get(currentBatchId - 1)
+      if (prevBatchOff.isDefined) {
+        commitSources(prevBatchOff.get)
+      } else {
+        throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist")
+      }
+    }
+  }
+
+    /** Execute a function while locking the stream from making an progress */

Review Comment:
   nit: indentation



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"
+          )
+        }
+      }
+      offsetLog.removeAsyncOffsetWrite(currentBatchId)
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  // need to look at the number of files on disk
+  override def purge(threshold: Long): Unit = {
+    while (offsetLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      offsetLog.writtenToDurableStorage.poll()
+    }
+    offsetLog.purge(offsetLog.writtenToDurableStorage.peek())
+
+    while (commitLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      commitLog.writtenToDurableStorage.poll()
+    }
+    commitLog.purge(commitLog.writtenToDurableStorage.peek())
+  }
+
+  override def cleanup(): Unit = {
+    super.cleanup()
+
+    ThreadUtils.shutdown(asyncWritesExecutorService)
+    logInfo(s"Async progress tracking executor pool for query ${prettyIdString} has been shutdown")
+  }
+
+  // used for testing
+  def areWritesPendingOrInProgress(): Boolean = {
+    asyncWritesExecutorService.getQueue.size() > 0 || asyncWritesExecutorService.getActiveCount > 0
+  }
+
+  private def validateAndGetTrigger(): TriggerExecutor = {
+    // validate that the pipeline is using a supported sink
+    if (!extraOptions
+      .get(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK

Review Comment:
   nit: why not import `AsyncProgressTrackingMicroBatchExecution._` at the top of the class definition and drop the redundant `AsyncProgressTrackingMicroBatchExecution.` qualifier?
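
For illustration, a minimal sketch of the companion-object import being suggested. `AsyncTracking` and its two constants are hypothetical, trimmed-down stand-ins for the class under review (only the option strings are taken from the diff); this is not the PR's code.

```scala
// Hypothetical stand-in for the companion object in the diff.
object AsyncTracking {
  val CHECKPOINTING_INTERVAL_MS = "asyncProgressTrackingCheckpointIntervalMs"
  val OVERRIDE_SINK_SUPPORT_CHECK = "_asyncProgressTrackingOverrideSinkSupportCheck"
}

class AsyncTracking(extraOptions: Map[String, String]) {
  // Bring the companion's constants into scope once for the whole class body,
  // so the members below do not need the `AsyncTracking.` qualifier.
  import AsyncTracking._

  val checkpointingIntervalMs: Long =
    extraOptions.getOrElse(CHECKPOINTING_INTERVAL_MS, "1000").toLong

  val overrideSinkSupportCheck: Boolean =
    extraOptions.getOrElse(OVERRIDE_SINK_SUPPORT_CHECK, "false").toBoolean
}

object ImportDemo {
  def main(args: Array[String]): Unit = {
    val exec = new AsyncTracking(Map(AsyncTracking.CHECKPOINTING_INTERVAL_MS -> "250"))
    println(exec.checkpointingIntervalMs)  // interval parsed from the supplied option
    println(exec.overrideSinkSupportCheck) // falls back to the "false" default
  }
}
```

Code outside the class still needs the qualifier (as in `ImportDemo`), which keeps the constants discoverable while shortening the call sites inside the class.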



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"
+          )
+        }
+      }
+      offsetLog.removeAsyncOffsetWrite(currentBatchId)
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  // need to look at the number of files on disk
+  override def purge(threshold: Long): Unit = {
+    while (offsetLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      offsetLog.writtenToDurableStorage.poll()
+    }
+    offsetLog.purge(offsetLog.writtenToDurableStorage.peek())
+
+    while (commitLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      commitLog.writtenToDurableStorage.poll()
+    }
+    commitLog.purge(commitLog.writtenToDurableStorage.peek())
+  }
+
+  override def cleanup(): Unit = {
+    super.cleanup()
+
+    ThreadUtils.shutdown(asyncWritesExecutorService)
+    logInfo(s"Async progress tracking executor pool for query ${prettyIdString} has been shutdown")
+  }
+
+  // used for testing
+  def areWritesPendingOrInProgress(): Boolean = {
+    asyncWritesExecutorService.getQueue.size() > 0 || asyncWritesExecutorService.getActiveCount > 0
+  }
+
+  private def validateAndGetTrigger(): TriggerExecutor = {
+    // validate that the pipeline is using a supported sink
+    if (!extraOptions
+      .get(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK
+      )
+      .getOrElse("false")
+      .toBoolean) {
+      try {
+        plan.sink.name() match {
+          case "noop-table" =>

Review Comment:
   Let's not forget to describe which sinks we support in the guide doc.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala:
##########
@@ -300,10 +322,8 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
     }
   }
 
-
   /**
-   * List the available batches on file system. As a workaround for S3 inconsistent list, it also
-   * tries to take `batchCache` into consideration to infer a better answer.
+   * List the available batches on file system

Review Comment:
   nit: Let's use a one-line comment: `/** List the available batches on file system. */`



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"
+          )
+        }
+      }
+      offsetLog.removeAsyncOffsetWrite(currentBatchId)
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  // need to look at the number of files on disk
+  override def purge(threshold: Long): Unit = {
+    while (offsetLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      offsetLog.writtenToDurableStorage.poll()
+    }
+    offsetLog.purge(offsetLog.writtenToDurableStorage.peek())
+
+    while (commitLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      commitLog.writtenToDurableStorage.poll()
+    }
+    commitLog.purge(commitLog.writtenToDurableStorage.peek())
+  }
+
+  override def cleanup(): Unit = {
+    super.cleanup()
+
+    ThreadUtils.shutdown(asyncWritesExecutorService)
+    logInfo(s"Async progress tracking executor pool for query ${prettyIdString} has been shutdown")
+  }
+
+  // used for testing
+  def areWritesPendingOrInProgress(): Boolean = {
+    asyncWritesExecutorService.getQueue.size() > 0 || asyncWritesExecutorService.getActiveCount > 0
+  }
+
+  private def validateAndGetTrigger(): TriggerExecutor = {
+    // validate that the pipeline is using a supported sink
+    if (!extraOptions
+      .get(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK
+      )
+      .getOrElse("false")
+      .toBoolean) {
+      try {
+        plan.sink.name() match {
+          case "noop-table" =>
+          case "console" =>
+          case "MemorySink" =>
+          case "KafkaTable" =>
+          case _ =>
+            throw new IllegalArgumentException(
+              s"Sink ${plan.sink.name()}" +
+                s" does not support async progress tracking"
+            )
+        }
+      } catch {
+        case e: IllegalStateException =>
+          // sink does not implement name() method
+          if (e.getMessage.equals("should not be called.")) {

Review Comment:
   Is it really necessary to compare the message against a string literal? Consider when we would get an IllegalStateException here: it is name() that throws it, which means we have no way to get the sink's name at all.
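
A rough sketch of what dropping the message comparison could look like, assuming any IllegalStateException raised by name() can be treated as "sink name unavailable". `NamedSink`, `SinkCheck` and `validateSink` are hypothetical names invented for this sketch and are not Spark APIs; only the four sink names come from the match in the diff above.

```scala
// Hypothetical sink abstraction used only for this sketch.
trait NamedSink {
  def name(): String
}

object SinkCheck {
  // Sink names copied from the match expression in the diff.
  private val supportedSinks = Set("noop-table", "console", "MemorySink", "KafkaTable")

  def validateSink(sink: NamedSink): Unit = {
    val sinkName =
      try {
        sink.name()
      } catch {
        // No inspection of e.getMessage: if name() throws IllegalStateException,
        // the name cannot be obtained, so the sink is reported as unsupported.
        case e: IllegalStateException =>
          throw new IllegalArgumentException(
            s"Sink $sink does not support async progress tracking", e)
      }
    if (!supportedSinks.contains(sinkName)) {
      throw new IllegalArgumentException(
        s"Sink $sinkName does not support async progress tracking")
    }
  }
}

object SinkCheckDemo {
  def main(args: Array[String]): Unit = {
    val console = new NamedSink { def name(): String = "console" }
    val unnamed = new NamedSink {
      def name(): String = throw new IllegalStateException("should not be called.")
    }
    SinkCheck.validateSink(console) // supported sink: returns normally
    try SinkCheck.validateSink(unnamed)
    catch { case e: IllegalArgumentException => println(e.getMessage) }
  }
}
```

Keeping the original exception as the cause preserves the stack trace for debugging without tying the check to a particular message string.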



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"
+          )
+        }
+      }
+      offsetLog.removeAsyncOffsetWrite(currentBatchId)
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  // need to look at the number of files on disk
+  override def purge(threshold: Long): Unit = {
+    while (offsetLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      offsetLog.writtenToDurableStorage.poll()
+    }
+    offsetLog.purge(offsetLog.writtenToDurableStorage.peek())
+
+    while (commitLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      commitLog.writtenToDurableStorage.poll()
+    }
+    commitLog.purge(commitLog.writtenToDurableStorage.peek())
+  }
+
+  override def cleanup(): Unit = {
+    super.cleanup()
+
+    ThreadUtils.shutdown(asyncWritesExecutorService)
+    logInfo(s"Async progress tracking executor pool for query ${prettyIdString} has been shutdown")
+  }
+
+  // used for testing
+  def areWritesPendingOrInProgress(): Boolean = {
+    asyncWritesExecutorService.getQueue.size() > 0 || asyncWritesExecutorService.getActiveCount > 0
+  }
+
+  private def validateAndGetTrigger(): TriggerExecutor = {
+    // validate that the pipeline is using a supported sink
+    if (!extraOptions
+      .get(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK
+      )
+      .getOrElse("false")
+      .toBoolean) {
+      try {
+        plan.sink.name() match {
+          case "noop-table" =>
+          case "console" =>
+          case "MemorySink" =>
+          case "KafkaTable" =>
+          case _ =>
+            throw new IllegalArgumentException(
+              s"Sink ${plan.sink.name()}" +
+                s" does not support async progress tracking"
+            )
+        }
+      } catch {
+        case e: IllegalStateException =>
+          // sink does not implement name() method
+          if (e.getMessage.equals("should not be called.")) {
+            throw new IllegalArgumentException(
+              s"Sink ${plan.sink}" +
+                s" does not support async progress tracking"
+            )
+          } else {
+            throw e
+          }
+      }
+    }
+
+    trigger match {
+      case t: ProcessingTimeTrigger => ProcessingTimeExecutor(t, triggerClock)
+      case OneTimeTrigger =>
+        throw new IllegalArgumentException(
+          "Async progress tracking cannot be used with Once trigger")
+      case AvailableNowTrigger =>
+        throw new IllegalArgumentException(
+          "Async progress tracking cannot be used with AvailableNow trigger"
+        )
+      case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger")
+    }
+  }
+
+  private def checkNotStatefulPipeline: Unit = {
+    if (isFirstBatch) {
+      lastExecution.executedPlan.collect {
+        case p if p.isInstanceOf[StateStoreWriter] =>
+          throw new IllegalArgumentException(
+            "Stateful streaming queries do not support async progress tracking at this moment."
+          )
+          isFirstBatch = false

Review Comment:
   This is not reachable: `isFirstBatch = false` sits after the `throw` in the same case branch, so it can never execute.
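
One possible restructuring, as a minimal sketch only: scan the plan first, throw if a stateful operator is found, and clear the flag afterwards so the assignment is actually reached. `PlanNode`, `StatelessNode`, `StatefulNode` and `FirstBatchValidator` are hypothetical stand-ins invented for illustration, not classes from Spark or from this PR.

```scala
// Hypothetical stand-ins for physical plan nodes; these are not Spark classes.
sealed trait PlanNode { def children: Seq[PlanNode] }
final case class StatelessNode(children: Seq[PlanNode] = Nil) extends PlanNode
final case class StatefulNode(children: Seq[PlanNode] = Nil) extends PlanNode

final class FirstBatchValidator {
  // Mirrors the isFirstBatch flag in the diff: validate once, then skip.
  private var isFirstBatch = true

  // Depth-first search for any stateful operator in the plan tree.
  private def containsStateful(node: PlanNode): Boolean = node match {
    case _: StatefulNode => true
    case other => other.children.exists(containsStateful)
  }

  def checkNotStatefulPipeline(plan: PlanNode): Unit = {
    if (isFirstBatch) {
      if (containsStateful(plan)) {
        throw new IllegalArgumentException(
          "Stateful streaming queries do not support async progress tracking at this moment.")
      }
      // Reached only when the scan found no stateful operator.
      isFirstBatch = false
    }
  }

  // Exposed for illustration so callers can observe whether validation already ran.
  def firstBatchPending: Boolean = isFirstBatch
}
```

With this shape the flag is cleared only after a successful scan, so later batches skip the check, and a plan that fails validation leaves the flag untouched.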





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052692275


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been received, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries up to batch 20 should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error by writing a file for the next offset i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 5; files are deleted below to create gaps in both logs
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 2; files are deleted below to create gaps in both logs
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1

Review Comment:
   This is for supporting turning async progress tracking on and off
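
The tests in this file check the on/off behaviour by listing the numeric file names under `offsets` and `commits` after each restart. As a rough illustration only, that repeated listing could be factored into a helper like the one below. `CheckpointLogListing` and `loggedBatchIds` are hypothetical names that are not part of this PR, and the sketch filters on all-digit file names, which is stricter than the `isHidden` check used in the suite.

```scala
import java.io.File
import java.nio.file.Files

// Hypothetical helper, not part of the PR: lists the batch ids that have a
// log file in one checkpoint sub-directory (for example "offsets" or "commits").
object CheckpointLogListing {

  // Returns the sorted batch ids found in logDir. Files whose names are not
  // purely numeric (such as ".0.crc" checksum files) are ignored. A missing
  // directory yields an empty list.
  def loggedBatchIds(logDir: File): List[Int] = {
    val files = Option(logDir.listFiles).getOrElse(Array.empty[File])
    files
      .filter(f => f.isFile && f.getName.nonEmpty && f.getName.forall(_.isDigit))
      .map(_.getName.toInt)
      .sorted
      .toList
  }

  def main(args: Array[String]): Unit = {
    // Build a throwaway directory layout that mimics a checkpoint location.
    val checkpoint = Files.createTempDirectory("streaming.metadata").toFile
    val offsets = new File(checkpoint, "offsets")
    val commits = new File(checkpoint, "commits")
    offsets.mkdirs()
    commits.mkdirs()
    Seq("0", "1", "2", ".0.crc").foreach(n => new File(offsets, n).createNewFile())
    Seq("0", "1").foreach(n => new File(commits, n).createNewFile())

    // Batch 2 has an offset entry but no commit entry, so it would be the
    // batch to look at when reasoning about replay after a restart.
    val uncommitted = loggedBatchIds(offsets).diff(loggedBatchIds(commits))
    println(s"offsets=${loggedBatchIds(offsets)} commits=${loggedBatchIds(commits)}")
    println(s"logged but not committed: $uncommitted")
  }
}
```

With such a helper, each assertion of the form "offsets should equal commits" reduces to comparing two lists, which may make it easier to see which batches are expected after async progress tracking is switched on or off.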



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR closed pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR closed pull request #38517: [SPARK-39591][SS] Async Progress Tracking
URL: https://github.com/apache/spark/pull/38517


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052712497


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1765 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.File
+import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.TestUtils
+import org.apache.spark.sql._
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.Utils
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+  extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // results should contain every value exactly once (no duplicates from reprocessing)
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // the commit log should contain entries for the same 10 additional batches
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be processed, so one more offset log entry should be added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  def testAsyncWriteErrorsAlreadyExists(path: String): Unit = {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + path))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            TestUtils.assertExceptionMsg(e,
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 1:" +
+    " offset file already exists for a batch") {
+    testAsyncWriteErrorsAlreadyExists("/offsets/1")
+  }
+
+  // Tests that errors that occur during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1:" +
+    " commit file already exists for a batch") {
+    testAsyncWriteErrorsAlreadyExists("/commits/1")
+  }
+
+  def testAsyncWriteErrorsPermissionsIssue(path: String): Unit = {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + path)
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2:" +
+    " cannot write offset files due to permissions issue") {
+    testAsyncWriteErrorsPermissionsIssue("/offsets")
+  }
+
+  // Tests that errors that occur during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2:" +
+    " cannot write commit files due to permissions issue") {
+    testAsyncWriteErrorsPermissionsIssue("/commits")
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+    extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach{ thread =>
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    }
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test("recovery when first offset is not zero and no commit log entries") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only
+    // contains batches 0, 2, 5 and the commit log only contains 0, 2
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // since the last committed batch is 2, the data after it (3, 4, 5) should be replayed
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test("recovery when gaps in offset and commit log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only
+    // contains batches 0, 2 and the commit log only contains 0
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+    // batches 0 and 3 should be logged

Review Comment:
   will update
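
The suite above repeats the same listing chain (`getListOfFiles(...).filter(...).map(...).sorted`) for every offset and commit log assertion. A minimal sketch of how that could be factored out is below. `CheckpointFiles` and `batchIds` are hypothetical names, not part of the PR. The sketch also assumes that log entries are the files whose names consist only of digits, which skips `.N.crc` checksum files by name rather than by `isHidden`.

```scala
import java.io.File
import java.nio.file.Files

object CheckpointFiles {
  // Hypothetical helper (not part of the PR): returns the sorted batch ids
  // found in a checkpoint sub-directory such as "offsets" or "commits".
  def batchIds(checkpointLocation: String, logName: String): Seq[Int] = {
    val dir = new File(checkpointLocation, logName)
    // listFiles() returns null when the directory does not exist
    Option(dir.listFiles()).map(_.toSeq).getOrElse(Seq.empty)
      .map(_.getName)
      // keep only purely numeric names; this drops ".N.crc" checksum files
      .filter(name => name.nonEmpty && name.forall(_.isDigit))
      .map(_.toInt)
      .sorted
  }

  def main(args: Array[String]): Unit = {
    // Build a throwaway directory that mimics a checkpoint layout
    val checkpoint = Files.createTempDirectory("streaming.metadata").toFile
    val offsets = new File(checkpoint, "offsets")
    offsets.mkdirs()
    Seq("0", "3", ".3.crc").foreach(n => new File(offsets, n).createNewFile())
    println(batchIds(checkpoint.getPath, "offsets"))
  }
}
```

With such a helper, an assertion in the suite could shrink to something like `batchIds(checkpointLocation, "offsets") should equal(Seq(0, 3))`.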



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on PR #38517:
URL: https://github.com/apache/spark/pull/38517#issuecomment-1358828963

   @HeartSaVioR I have addressed your comments please take another look




[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052719778


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())

Review Comment:
   Yup will resolve. Thanks for making this clear.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052722369


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -342,17 +342,14 @@ class MicroBatchExecution(
         isCurrentBatchConstructed = true
         availableOffsets = nextOffsets.toStreamProgress(sources)
         /* Initialize committed offsets to a committed batch, which at this
-         * is the second latest batch id in the offset log. */
-        if (latestBatchId != 0) {
-          val secondLatestOffsets = offsetLog.get(latestBatchId - 1).getOrElse {
-            logError(s"The offset log for batch ${latestBatchId - 1} doesn't exist, " +
-              s"which is required to restart the query from the latest batch $latestBatchId " +
-              "from the offset log. Please ensure there are two subsequent offset logs " +
-              "available for the latest batch via manually deleting the offset file(s). " +
-              "Please also ensure the latest batch for commit log is equal or one batch " +
-              "earlier than the latest batch for offset log.")
-            throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")
-          }
+         * is the second latest batch id in the offset log.

Review Comment:
   How often we encounter this is not important. What matters is that we are removing a guard whose purpose is to keep the fault-tolerance semantics from breaking.
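
   For illustration, the guard under discussion can be sketched in isolation as below. This is a minimal sketch under stated assumptions, not the code in `MicroBatchExecution`: `InMemoryOffsetLog`, `RestartGuard`, and `committedOffsetsOnRestart` are hypothetical names, and offsets are modelled as plain strings. It only shows the fail-fast behaviour of the removed lines: when restarting from `latestBatchId`, the entry for `latestBatchId - 1` must exist.

```scala
// Hypothetical stand-in for the offset log: batch id -> serialized offsets.
final class InMemoryOffsetLog(entries: Map[Long, String]) {
  def get(batchId: Long): Option[String] = entries.get(batchId)
}

object RestartGuard {
  // Returns the offsets to treat as committed when restarting from latestBatchId.
  // Batch 0 has no predecessor, so there is nothing to look up.
  def committedOffsetsOnRestart(
      log: InMemoryOffsetLog,
      latestBatchId: Long): Option[String] = {
    if (latestBatchId == 0) {
      None
    } else {
      val previous = latestBatchId - 1
      Some(log.get(previous).getOrElse {
        // Fail fast instead of restarting without known committed offsets.
        throw new IllegalStateException(s"batch $previous doesn't exist")
      })
    }
  }
}
```

   With async progress tracking the previous persisted batch is not necessarily `latestBatchId - 1` (the new class quoted earlier in this thread looks it up via `offsetLog.getPrevBatchFromStorage`), so the check presumably has to change shape rather than disappear.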





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052639318


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)

Review Comment:
   ok





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052672070


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote its commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // every value should have been received exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // no additional commit log entries should be logged since async offset commit is on
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // the new data is processed as batch 21, so one more offset log entry should be added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write are bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write are bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write are bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write are bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are being executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 5, then delete entries from the logs to create gaps
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()

Review Comment:
   oh I see what you are saying
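
   The suite quoted above repeats the same chain for nearly every assertion: list the files in a checkpoint log directory, drop hidden files, parse the file names as batch ids, and sort. A minimal sketch of how that chain could be folded into one helper follows. The names `CheckpointLogListing` and `batchIdsIn` are hypothetical and not part of the PR, and the sketch assumes (as the tests do) that every non-hidden file name in the directory is an integer batch id.

```scala
import java.io.File

object CheckpointLogListing {
  // Hypothetical helper, not part of the PR: returns the sorted batch ids of
  // the non-hidden regular files directly under `dir`, or an empty list if
  // `dir` does not exist or is not a directory.
  def batchIdsIn(dir: String): List[Int] = {
    val d = new File(dir)
    if (d.exists && d.isDirectory) {
      d.listFiles.toList
        .filter(f => f.isFile && !f.isHidden)
        .map(_.getName.toInt) // throws if a file name is not numeric, like the original chain
        .sorted
    } else {
      List.empty[Int]
    }
  }
}
```

   With such a helper, an assertion like the ones in the suite could read `CheckpointLogListing.batchIdsIn(checkpointLocation + "/offsets") should equal(List(0, 3))`.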



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049255090


##########
connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala:
##########
@@ -195,6 +200,102 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
     true
   }
 
+  /**
+   * Test async progress tracking capability with Kafka source and sink
+   */
+  test("async progress tracking") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 5)
+
+    val dataSent = new ListBuffer[String]()
+    testUtils.sendMessages(inputTopic, (0 until 15).map { case x =>
+      val m = s"foo-$x"
+      dataSent += m
+      m
+    }.toArray, Some(0))
+
+    val outputTopic = newTopic()
+    testUtils.createTopic(outputTopic, partitions = 5)
+
+    withTempDir { dir =>
+      val reader = spark
+        .readStream
+        .format("kafka")
+        .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+        .option("kafka.metadata.max.age.ms", "1")
+        .option("maxOffsetsPerTrigger", 5)
+        .option("subscribe", inputTopic)
+        .option("startingOffsets", "earliest")
+        .load()
+
+      def startQuery(): StreamingQuery = {
+        reader.writeStream
+          .format("kafka")
+          .option("checkpointLocation", dir.getCanonicalPath)
+          .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+          .option("kafka.max.block.ms", "5000")
+          .option("topic", outputTopic)
+          .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+          .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+          .queryName("kafkaStream")
+          .start()
+      }
+
+      def readResults(): ListBuffer[String] = {

Review Comment:
   ok





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048938757


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long =
+    AsyncProgressTrackingMicroBatchExecution
+      .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true

Review Comment:
   Let me just add a comment
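
   For readers following the `isFirstBatch` flag, here is a rough sketch of the behavior that the "interval commits and recovery" test appears to expect: the first batch after a (re)start is always persisted, and later batches are persisted only once the checkpointing interval has elapsed. This is an assumption inferred from the test expectations (batches 0 and 3, then 4 and 7, with a 1000 ms interval), not the PR's actual implementation. `IntervalGate` and `shouldPersist` are hypothetical names, and the `>=` comparison is a guess.

```scala
// Hypothetical sketch, not the PR's code: decides whether a batch's progress
// should be written to durable storage, given the current clock time.
final class IntervalGate(intervalMs: Long) {
  // True until the first batch after (re)start has been persisted.
  private var isFirstBatch: Boolean = true
  // Clock time of the most recent persisted batch.
  private var lastPersistedAtMs: Long = 0L

  def shouldPersist(nowMs: Long): Boolean = {
    if (isFirstBatch || nowMs - lastPersistedAtMs >= intervalMs) {
      isFirstBatch = false
      lastPersistedAtMs = nowMs
      true
    } else {
      false
    }
  }
}
```

   Under these assumptions, a gate with a 1000 ms interval queried at clock times 100, 200, 300 and 1100 persists only the first and last batch, which lines up with the `Array(0, 3)` assertion in the test.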





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049074121


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for all batches so far (0 through 4) should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for all batches (0 through 8) should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been received, in order and without duplicates
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // the commit log should contain the same 10 additional batches as the offset log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should also be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by pre-creating the file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by pre-creating the commit file for the next batch, i.e. 1
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the offsets directory read-only
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the commits directory read-only
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty

Review Comment:
   will correct comment



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049070447


##########
core/src/main/scala/org/apache/spark/util/ThreadUtils.scala:
##########
@@ -167,6 +167,27 @@ private[spark] object ThreadUtils {
     Executors.newFixedThreadPool(1, threadFactory).asInstanceOf[ThreadPoolExecutor]
   }
 
+  /**
+   * Wrapper over newSingleThreadExecutor that allows the specification
+   * of a RejectedExecutionHandler
+   */
+  def newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    threadName: String,

Review Comment:
   Not really. Please note that parameters in a method call and parameters in a method definition are indented differently: 2 spaces for the former, 4 spaces for the latter.
   https://github.com/databricks/scala-style-guide#indent
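
To illustrate the difference, here is a minimal sketch. The names `IndentExample` and `describeExecutor` are hypothetical and are not part of this PR; only the indentation matters.

```scala
object IndentExample {
  // Method definition: parameters wrapped onto their own lines are indented 4 spaces.
  // (describeExecutor is a hypothetical helper used only to show the layout.)
  def describeExecutor(
      threadName: String,
      queueCapacity: Int): String = {
    s"$threadName(capacity=$queueCapacity)"
  }

  // Method call: arguments wrapped onto their own lines are indented 2 spaces.
  val description: String = describeExecutor(
    "async-log-write",
    2)
}
```

Applied to the method under review, `threadName: String,` and the parameters after it would move from 2 to 4 spaces of indentation.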
   





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048943509


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())

Review Comment:
   We don't need to check the commit log. An entry for batch n in the offset log indicates that the previous batch (n-1) has been committed successfully to the commit log.
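
As a rough sketch of that invariant (not code from this PR): given the batch ids present in the offset log, the newest id n implies that everything up to n-1 has been committed. `OffsetLogInvariant` and `latestImpliedCommit` are hypothetical names chosen for illustration.

```scala
object OffsetLogInvariant {
  // Hypothetical helper, not part of Spark: returns the highest batch id that can be
  // assumed committed, based only on the batch ids found in the offset log.
  // Assumption: an offset log entry for batch n is written only after batch n-1 committed.
  def latestImpliedCommit(offsetLogBatchIds: Seq[Long]): Option[Long] =
    offsetLogBatchIds
      .reduceOption(_ max _) // newest batch id in the offset log, if any
      .map(_ - 1)            // the batch before it must have committed
      .filter(_ >= 0)        // batch 0 has no predecessor
}
```

With offset log entries 0 and 3, as in the gap tests in this suite, the sketch yields batch 2 without reading the commit log.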





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049093235


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batch 2, 3, 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because commit log is on previously
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // no new commits should be logged since async offset commits are enabled
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batch 2, 3, 4, 6, 7, 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 20 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should also be written
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // the new data is processed as batch 21, so one more offset log entry should be added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log contains batches 0, 2 and 5
+    // and the commit log contains batches 0 and 2
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log contains batches 0 and 2
+    // and the commit log contains only batch 0
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+  }
+
+  test("switching async progress tracking with interval commits on and off") {
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    var clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    inputData.commits.length should be(0)
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(3).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(4).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    /**
+     * Turn async progress tracking on
+     */
+    clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 6), // will get persisted since first batch on restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      AddData(inputData, 7),
+      AdvanceManualClock(100),
+      CheckNewAnswer(7),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(5).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 8),
+      AdvanceManualClock(100),
+      CheckNewAnswer(8),
+      AddData(inputData, 9),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(9),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    // simulate batch 9 not having a commit log entry
+    new File(checkpointLocation + "/commits/9").delete()
+    new File(checkpointLocation + "/commits/.9.crc").delete()
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work

Review Comment:
   will remove



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048961319


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala:
##########
@@ -148,6 +148,24 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
     }
   }
 
+  /**
+   * Get the id of the previous batch from storage

Review Comment:
   ok
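
   For context on why the previous batch id has to be read from storage: with a non-zero checkpointing interval the logged batch ids are not contiguous (the test in this PR expects `0, 3, 4, 5, 6, 9`), so `batchId - 1` may not exist in the log. A minimal sketch of such a lookup follows, assuming the logged ids are already available as a sequence. `PreviousBatch` and `previousBatchId` are hypothetical names, and this is not the implementation in `HDFSMetadataLog`.

   ```scala
   object PreviousBatch {
     // Hypothetical stand-in for a lookup against the metadata log directory:
     // returns the largest logged id strictly smaller than batchId, if any.
     def previousBatchId(loggedIds: Seq[Long], batchId: Long): Option[Long] =
       loggedIds.filter(_ < batchId).sorted.lastOption

     def main(args: Array[String]): Unit = {
       val logged = Seq(0L, 3L, 4L, 5L, 6L, 9L)
       println(previousBatchId(logged, 9L)) // prints "Some(6)"
       println(previousBatchId(logged, 3L)) // prints "Some(0)"
       println(previousBatchId(logged, 0L)) // prints "None"
     }
   }
   ```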





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049069404


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been collected exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error by writing a file for the next offset i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),

Review Comment:
   The expected behavior is that it will replay all previous data up to and including batch 2.
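
To make that expectation concrete, here is a minimal sketch. It is a simplified model, not the recovery code in this PR: everything after the latest committed batch, up to the latest batch found in the offset log, has to be re-run. `ReplaySketch` and `batchesToReplay` are hypothetical names introduced only for illustration.

```scala
object ReplaySketch {
  // Hypothetical model: returns the inclusive range of batch ids whose data
  // must be re-run on restart, or None when nothing needs replaying.
  def batchesToReplay(
      offsetLogIds: Seq[Long],
      commitLogIds: Seq[Long]): Option[(Long, Long)] = {
    val latestOffset = offsetLogIds.sorted.lastOption
    val latestCommit = commitLogIds.sorted.lastOption
    latestOffset.flatMap { end =>
      // With an empty commit log, replay starts from the beginning.
      val start = latestCommit.map(_ + 1).getOrElse(0L)
      if (start <= end) Some((start, end)) else None
    }
  }
}
```

For the scenario in this test (offset log holds only batch 2, commit log is empty), the model asks for everything from the beginning through batch 2, which is what `CheckNewAnswer(0, 1, 2)` checks.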



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052937189


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -68,11 +64,42 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
     }
   }
 
+  class MemoryStreamCapture[A: Encoder](
+                                         id: Int,

Review Comment:
   nit: 4 spaces
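
For reference, a small sketch of the layout this nit asks for, assuming the usual Spark Scala convention of a 4-space indent for wrapped constructor parameters and a 2-space indent for the `extends` clause. `BaseStream` and `CaptureStream` are hypothetical stand-ins so the snippet compiles on its own; they are not classes from this PR.

```scala
// Hypothetical stand-in for the parent class so the snippet is self-contained.
class BaseStream[A](id: Int, numPartitions: Option[Int])

// Wrapped constructor parameters are indented 4 spaces;
// the `extends` clause is indented 2 spaces.
class CaptureStream[A](
    id: Int,
    label: String,
    numPartitions: Option[Int] = None)
  extends BaseStream[A](id, numPartitions) {

  def describe: String = s"$label-$id"
}
```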





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052964642


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -78,7 +72,9 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   test("async WAL commits happy path") {
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
 
-    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+//    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)

Review Comment:
   already done



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -283,3 +264,23 @@ class AsyncProgressTrackingMicroBatchExecution(
     }
   }
 }
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+                                                               extraOptions: Map[String, String]): Long = {

Review Comment:
   already done
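
For readers following the thread, here is a sketch of what the checkpointing-interval option controls. The lookup with a default of "1000" mirrors the helper shown in this PR. The gating rule in `persistedBatches` is an assumption inferred from the "interval commits and recovery" test (batches 0 and 3 logged with a 1000 ms interval); it is not the PR's implementation. `CheckpointIntervalSketch` and its members are hypothetical names.

```scala
object CheckpointIntervalSketch {
  // Option key as declared in the hunk above.
  val IntervalKey = "asyncProgressTrackingCheckpointIntervalMs"

  // Same lookup as the PR's helper: fall back to "1000" when the option is absent.
  def intervalMs(extraOptions: Map[String, String]): Long =
    extraOptions.getOrElse(IntervalKey, "1000").toLong

  // Assumed rule: persist the first batch, then persist again once at least
  // `intervalMs` has elapsed since the last persisted batch.
  // batchTimesMs(i) is the clock reading when batch i runs.
  def persistedBatches(batchTimesMs: Seq[Long], intervalMs: Long): Seq[Int] = {
    val start = (Option.empty[Long], Vector.empty[Int])
    val (_, persisted) = batchTimesMs.zipWithIndex.foldLeft(start) {
      case ((last, acc), (now, batchId)) =>
        if (last.forall(now - _ >= intervalMs)) (Some(now), acc :+ batchId)
        else (last, acc)
    }
    persisted
  }
}
```

Under this model, batches running at clock readings 100, 200, 300 and 1100 ms (the manual clock advances in that test, assuming each batch runs right after its advance) give persisted batches 0 and 3, while an interval of 0 persists every batch.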





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052535800


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+  sparkSession: SparkSession,
+  trigger: Trigger,
+  triggerClock: Clock,
+  extraOptions: Map[String, String],
+  plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  import AsyncProgressTrackingMicroBatchExecution._
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  // used to check during the first batch if the pipeline is stateful
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if streamign query is stateful

Review Comment:
   will fix
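
As a side note on the executor in the hunk above: the single write thread with a rejection handler that blocks the producer can be sketched in isolation. This version uses a plain JDK `ThreadPoolExecutor` in place of Spark's `ThreadUtils` helper, and `BlockingExecutorSketch`, `newExecutor` and `runInOrder` are hypothetical names, so treat it as an illustration of the pattern rather than the PR's code.

```scala
import java.util.concurrent.{ArrayBlockingQueue, RejectedExecutionException}
import java.util.concurrent.{RejectedExecutionHandler, ThreadPoolExecutor, TimeUnit}
import scala.collection.mutable.ListBuffer

object BlockingExecutorSketch {
  // When the bounded queue is full, block the submitting thread until there
  // is room instead of dropping the task.
  private val blockWhenFull: RejectedExecutionHandler = new RejectedExecutionHandler {
    override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
      if (!executor.isShutdown) {
        try {
          executor.getQueue.put(r)
        } catch {
          case e: InterruptedException =>
            Thread.currentThread().interrupt()
            throw new RejectedExecutionException("Producer interrupted", e)
        }
      }
    }
  }

  // One worker thread, so queued writes run strictly in submission order.
  def newExecutor(queueCapacity: Int): ThreadPoolExecutor =
    new ThreadPoolExecutor(
      1, 1, 0L, TimeUnit.MILLISECONDS,
      new ArrayBlockingQueue[Runnable](queueCapacity),
      blockWhenFull)

  // Submits n tasks from one thread and returns the order in which they ran.
  def runInOrder(n: Int): List[Int] = {
    val seen = new ListBuffer[Int]() // appended to only by the single worker thread
    val pool = newExecutor(queueCapacity = 2)
    (0 until n).foreach { i =>
      pool.execute(new Runnable { override def run(): Unit = seen += i })
    }
    pool.shutdown()
    pool.awaitTermination(10, TimeUnit.SECONDS)
    seen.toList
  }
}
```

The small queue capacity matches the hunk's comment about needing one slot for the offset write and one for the commit write; when both are taken, the submitting thread waits rather than reordering or losing writes.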





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052536573


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {

Review Comment:
   will change





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052941322


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 0 through 6 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all values 0 through 39 should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 5; gaps in the offset and commit logs are created afterwards
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 2; gaps in the offset and commit logs are created afterwards
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+  }
+
+  test("switching async progress tracking with interval commits on and off") {
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    var clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    inputData.commits.length should be(0)
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(3).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(4).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0 and 3 should be logged

Review Comment:
   nit: it didn't seem to be removed. (while we are here...)
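
   For readers following the hunk above: the same listing chain (list the files, drop hidden ones, parse the names as batch ids, sort) recurs after every stream run. A later revision quoted in this thread calls a helper named `getBatchIdsSortedFromLog`, but its definition is not shown here. The block below is only a guess at what such a helper could look like, written against plain `java.io.File`; the object name `BatchIds` is hypothetical, and hidden-file detection is assumed to behave as on Unix-like systems (dot-prefixed names such as `.3.crc`).

```scala
import java.io.File

object BatchIds {
  // Hypothetical sketch: returns the batch ids found in a log directory,
  // ignoring hidden files (e.g. ".3.crc" checksums) and sub-directories.
  // A missing or unreadable directory yields an empty array.
  def getBatchIdsSortedFromLog(logDir: String): Array[Int] = {
    val files = Option(new File(logDir).listFiles).getOrElse(Array.empty[File])
    files
      .filter(f => f.isFile && !f.isHidden)
      .map(_.getName.toInt)
      .sorted
  }
}
```

   With something like this in place, each check collapses to a single line, as in the later hunk: `getBatchIdsSortedFromLog(checkpointLocation + "/offsets") should equal(Array(0, 3))`.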



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -1762,4 +1344,173 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
       }
     }
   }
+
+  test("test gaps in offset log") {
+    val inputData = MemoryStream[Int]
+    val streamEvent = inputData.toDF().select("value")
+
+    val resourceUri = this.getClass.getResource(
+      "/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/").toURI
+    val checkpointDir = Utils.createTempDir().getCanonicalFile
+    // Copy the checkpoint to a temp dir to prevent changes to the original.
+    // Not doing this will lead to the test passing on the first run, but fail subsequent runs.
+    FileUtils.copyDirectory(new File(resourceUri), checkpointDir)
+
+    // Not doing this will lead to the test passing on the first run, but fail subsequent runs.

Review Comment:
   nit: this looks redundant
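
   To illustrate the point, one copy of the fixture into a fresh temp directory is enough to keep the checked-in checkpoint untouched between runs. The sketch below is a dependency-free stand-in, not the PR's code: it uses `java.nio.file` instead of commons-io `FileUtils`, and `CopyCheckpointOnce` and `copyDirectory` are hypothetical names introduced only for this example.

```scala
import java.nio.file.{Files, Path}

object CopyCheckpointOnce {
  // Hypothetical stand-in for FileUtils.copyDirectory: recursively copies
  // everything under `src` into `dst`. Files.walk visits a directory before
  // its children, so parent directories exist by the time files are copied.
  def copyDirectory(src: Path, dst: Path): Unit = {
    val stream = Files.walk(src)
    try {
      stream.forEach { (p: Path) =>
        val target = dst.resolve(src.relativize(p).toString)
        if (Files.isDirectory(p)) Files.createDirectories(target)
        else Files.copy(p, target)
        ()
      }
    } finally {
      stream.close()
    }
  }

  def main(args: Array[String]): Unit = {
    // Build a tiny fake checkpoint fixture, then copy it exactly once.
    val fixture = Files.createTempDirectory("fixture")
    Files.createDirectories(fixture.resolve("offsets"))
    Files.write(fixture.resolve("offsets").resolve("0"), "v1".getBytes("UTF-8"))

    val checkpointDir = Files.createTempDirectory("checkpoint")
    copyDirectory(fixture, checkpointDir)
    assert(Files.exists(checkpointDir.resolve("offsets").resolve("0")))
  }
}
```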



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -1762,4 +1344,173 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
       }
     }
   }
+
+  test("test gaps in offset log") {
+    val inputData = MemoryStream[Int]
+    val streamEvent = inputData.toDF().select("value")
+
+    val resourceUri = this.getClass.getResource(
+      "/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/").toURI
+    val checkpointDir = Utils.createTempDir().getCanonicalFile
+    // Copy the checkpoint to a temp dir to prevent changes to the original.
+    // Not doing this will lead to the test passing on the first run, but fail subsequent runs.
+    FileUtils.copyDirectory(new File(resourceUri), checkpointDir)
+
+    // Not doing this will lead to the test passing on the first run, but fail subsequent runs.
+    FileUtils.copyDirectory(new File(resourceUri), checkpointDir)
+
+    testStream(streamEvent, extraOptions = Map(
+      ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+      ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+    ))(
+      AddData(inputData, 0),
+      AddData(inputData, 1),
+      AddData(inputData, 2),
+      AddData(inputData, 3),
+      AddData(inputData, 4),
+      StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+      CheckAnswer(3, 4)
+    )
+
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only
+    // contains batches 0, 2, 5 and the commit log only contains 0, 2
+    testStream(ds, extraOptions = Map(
+      ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+      ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+    ))(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    // delete all offset files except for batch 0, 2, 5
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filterNot(f => f.getName.startsWith("0")
+        || f.getName.startsWith("2")
+        || f.getName.startsWith("5"))
+      .foreach(_.delete())
+
+    // delete all commit log files except for batch 0, 2
+    getListOfFiles(checkpointLocation + "/commits")
+      .filterNot(f => f.getName.startsWith("0") || f.getName.startsWith("2"))
+      .foreach(_.delete())
+
+    getBatchIdsSortedFromLog(checkpointLocation + "/offsets") should equal(Array(0, 2, 5))
+    getBatchIdsSortedFromLog(checkpointLocation + "/commits") should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2, extraOptions = Map(
+      ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+      ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+    ))(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // the offset log contains batches 0, 2, 5 and the commit log contains
+      // batches 0, 2, which indicates that we have successfully processed up to batch 2.
+      // Thus the data we need to process / re-process is batches 3, 4, 5
+      CheckNewAnswer(3, 4, 5),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        eventually(timeout(Span(5, Seconds))) {
+          getBatchIdsSortedFromLog(checkpointLocation + "/offsets") should equal(Array(0, 2, 5))
+          getBatchIdsSortedFromLog(checkpointLocation + "/commits") should equal(Array(0, 2, 5))
+        }
+      },
+      StopStream
+    )
+
+    getBatchIdsSortedFromLog(checkpointLocation + "/offsets") should equal(Array(0, 2, 5))
+    getBatchIdsSortedFromLog(checkpointLocation + "/commits") should equal(Array(0, 2, 5))
+  }
+
+  test("recovery when gaps exist in offset and commit log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only
+    // contains batch 0, 2 and commit log only contains 9

Review Comment:
   nit: commit log only contains 0
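
   To make the scenario concrete, here is a toy model of what the recovery tests quoted in this thread expect. It is not Spark's recovery logic, and `ReplayRange` and `batchesToReplay` are hypothetical names. The assumption, read off the test expectations only, is that the highest id in the commit log marks the last fully processed batch, the highest id in the offset log marks how far data was planned, and everything in between is re-run on restart even when intermediate log files are missing.

```scala
object ReplayRange {
  // Toy model (assumption, not Spark's implementation): batch ids whose data
  // would be re-processed on restart, given the ids present in each log.
  def batchesToReplay(offsetLog: Seq[Int], commitLog: Seq[Int]): Seq[Int] = {
    val lastCommitted = if (commitLog.isEmpty) -1 else commitLog.max
    val lastPlanned = if (offsetLog.isEmpty) -1 else offsetLog.max
    (lastCommitted + 1) to lastPlanned
  }

  def main(args: Array[String]): Unit = {
    // offsets {0, 2}, commits {0}: batches 1 and 2 are replayed
    assert(batchesToReplay(Seq(0, 2), Seq(0)) == Seq(1, 2))
    // offsets {0, 2, 5}, commits {0, 2}: batches 3, 4, 5 are replayed
    assert(batchesToReplay(Seq(0, 2, 5), Seq(0, 2)) == Seq(3, 4, 5))
  }
}
```

   The two assertions mirror the `CheckNewAnswer(1, 2)` and `CheckNewAnswer(3, 4, 5)` expectations in the quoted tests, where each batch carries a single value equal to its batch id.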





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049021653


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was enabled in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 0 through 6 should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 21 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // every value should appear exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // 10 more commit log entries should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // the new data is processed as batch 21, so one more offset log entry should be added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")

Review Comment:
   This is unused code. I will remove it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049018991


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 3 and 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was enabled in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 7 and 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // 10 more commit log entries should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 20 in the commit log writes to durable storage

Review Comment:
   yes





[GitHub] [spark] jerrypeng commented on pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on PR #38517:
URL: https://github.com/apache/spark/pull/38517#issuecomment-1352612404

   @HeartSaVioR thanks for the detailed review. I think I have addressed all of your comments. PTAL




[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048792294


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+    extends CommitLog(sparkSession, path) {

Review Comment:
   OK





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049016528


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")

Review Comment:
   That is not really going to save any lines of code. This method is called for both "/offsets" and "/commits".





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049015857


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {

Review Comment:
   There are tests for that as well





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049083624


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"

Review Comment:
   That scenario is feasible for the current MicroBatchExecution, since multiple queries can compete for a specific path in external storage. A batch that we add only in memory won't trigger it. We don't even share the in-memory cache for the commit log across streaming queries in the driver.
   
   But I don't mind leaving this as it is. We can just consider it as respecting the previous contract.
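
The point about the in-memory cache can be illustrated with a minimal sketch. It assumes a hypothetical `InMemoryLog` class (not Spark's actual `CommitLog` or `AsyncCommitLog`) in which each query owns its own map, so an in-memory add for a batch id can only fail if the same instance already holds that batch:

```scala
import java.util.concurrent.ConcurrentHashMap

// Hypothetical stand-in for a per-query in-memory log cache.
// This is an illustration only; it is not Spark's implementation.
final class InMemoryLog[T <: AnyRef] {
  // Each instance owns its own cache; nothing is shared across instances.
  private val cache = new ConcurrentHashMap[Long, T]()

  // Returns true if the batch was newly added, false if this instance
  // already had an entry for the same batch id.
  def addInMemory(batchId: Long, metadata: T): Boolean =
    cache.putIfAbsent(batchId, metadata) == null
}

object InMemoryLogDemo {
  def main(args: Array[String]): Unit = {
    val queryA = new InMemoryLog[String]
    val queryB = new InMemoryLog[String]

    // Two separate queries adding the same batch id do not conflict,
    // because their caches are independent.
    println(queryA.addInMemory(0L, "watermark-a"))
    println(queryB.addInMemory(0L, "watermark-b"))

    // Only a repeated add on the same instance is rejected.
    println(queryA.addInMemory(0L, "watermark-a-again"))
  }
}
```

Under these assumptions, a conflict on the in-memory path could only come from the same query adding a batch twice, which is why the "multiple streaming jobs" message matters mainly for writes to shared external storage.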





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049095149


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for all batches, 0 through 4, should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for all batches, 0 through 8, should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // every value from 0 through 39 should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // the commit log should have matching entries for batches 11 through 20
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by pre-creating the file for the next offset, i.e. 1;
+        // this should cause a conflict when the query tries to write that offset
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by pre-creating the commit file for the next batch, i.e. 1
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the offsets directory read-only
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the commits directory read-only
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 5, then delete log files below to create multiple gaps
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 2, then delete log files below to create gaps
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // async progress tracking is off, so the default trigger and clock are used
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+  }
+
+  test("switching async progress tracking with interval commits on and off") {
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    var clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    inputData.commits.length should be(0)
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // async progress tracking is off, so the default trigger and clock are used
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(3).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(4).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    /**
+     * Turn async progress tracking on
+     */
+    clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 6), // will get persisted since first batch on restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      AddData(inputData, 7),
+      AdvanceManualClock(100),
+      CheckNewAnswer(7),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(5).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 8),
+      AdvanceManualClock(100),
+      CheckNewAnswer(8),
+      AddData(inputData, 9),
+      AdvanceManualClock(800), //  persist offset
+      CheckNewAnswer(9),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    // simulate batch 9 missing its commit log entry
+    new File(checkpointLocation + "/commits/9").delete()
+    new File(checkpointLocation + "/commits/.9.crc").delete()
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // async progress tracking is off, so the default trigger and clock are used
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 10),
+      CheckNewAnswer(7, 8, 9, 10),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(9).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 11),
+      CheckNewAnswer(11),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(10).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0, 3, 4, 5, 6, 9, 10 and 11 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9, 10, 11))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9, 10, 11))
+  }
+
+  test("Fail on stateful pipelines") {
+    val rateStream = spark.readStream
+      .format("rate")
+      .option("numPartitions", 1)
+      .option("rowsPerSecond", 10)
+      .load()
+      .toDF()
+
+    val windowedStream = rateStream
+      .withWatermark("timestamp", "0 seconds")
+      .groupBy(window(column("timestamp"), "10 seconds"), column("value"))
+      .count()
+
+    val e = intercept[StreamingQueryException] {
+      val query = windowedStream.writeStream
+        .format("noop")
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+
+      query.processAllAvailable()
+    }
+    e.getMessage should include(
+      "Stateful streaming queries does not support async progress tracking at this moment."
+    )
+  }
+
+  test("Fail on pipelines using unsupported sinks") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("parquet")
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(
+          "checkpointLocation",
+          Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+        )
+        .start("/tmp")
+    }
+
+    e.getMessage should equal("Sink FileSink[/tmp] does not support async progress tracking")
+  }
+
+  test("with log purging") {
+
+    withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "false") {
+      withTempDir { checkpointLocation =>
+        val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+        val ds = inputData.toDS()
+
+        val clock = new StreamManualClock
+        testStream(
+          ds,
+          extraOptions = Map(
+            ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+            ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+          )
+        )(
+          // need to set processing time to something so manual clock will work
+          StartStream(
+            Trigger.ProcessingTime("1 millisecond"),
+            checkpointLocation = checkpointLocation.getCanonicalPath,
+            triggerClock = clock
+          ),
+          AddData(inputData, 0),
+          AdvanceManualClock(100),
+          CheckNewAnswer(0),
+          AddData(inputData, 1),
+          AdvanceManualClock(100),
+          CheckNewAnswer(1),
+          AddData(inputData, 2),
+          AdvanceManualClock(100),
+          CheckNewAnswer(2),
+          AddData(inputData, 3),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(3),
+          AddData(inputData, 4),
+          AdvanceManualClock(100),
+          CheckNewAnswer(4),
+          AddData(inputData, 5),
+          AdvanceManualClock(100),
+          CheckNewAnswer(5),
+          AddData(inputData, 6),
+          AdvanceManualClock(100),
+          CheckNewAnswer(6),
+          AddData(inputData, 7),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(7),
+          Execute { q =>
+            // wait for all async log writes to finish
+            waitPendingOffsetWrites(q)
+          },
+          // add a new row to make sure log purge has kicked in.
+          // There can be a race condition in which the commit log entry for the previous batch
+          // may or may not be written to disk yet before the log purge is called.
+          // Adding another batch here will make sure purge is called on the correct number of
+          // offset and commit log entries
+          AddData(inputData, 8),
+          AdvanceManualClock(100),
+          CheckNewAnswer(8),
+          Execute { q =>
+            waitPendingOffsetWrites(q)
+            getListOfFiles(checkpointLocation + "/offsets")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(3, 7))
+            getListOfFiles(checkpointLocation + "/commits")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(3, 7))
+            q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].writtenToDurableStorage.size() should be(2)
+            q.commitLog.asInstanceOf[AsyncCommitLog].writtenToDurableStorage.size() should be(2)
+          },
+          StopStream
+        )
+
+        /**
+         * restart
+         */
+        testStream(
+          ds,
+          extraOptions = Map(
+            ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+            ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+          )
+        )(
+          // need to set processing time to something so manual clock will work
+          StartStream(
+            Trigger.ProcessingTime("1 millisecond"),
+            checkpointLocation = checkpointLocation.getCanonicalPath,
+            triggerClock = clock
+          ),
+          AddData(inputData, 9),
+          AdvanceManualClock(100),
+          CheckNewAnswer(8, 9),
+          AddData(inputData, 10),
+          AdvanceManualClock(100),
+          CheckNewAnswer(10),
+          AddData(inputData, 11),
+          AdvanceManualClock(100),
+          CheckNewAnswer(11),
+          AddData(inputData, 12),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(12),
+          Execute { q =>
+            // wait for all async log writes to finish
+            waitPendingOffsetWrites(q)
+          },
+          // add a new row to make sure log purge has kicked in.
+          // There can be a race condition in which the commit log entry for the previous batch
+          // may or may not be written to disk yet before the log purge is called.
+          // Adding another batch here will make sure purge is called on the correct number of
+          // offset and commit log entries
+          AddData(inputData, 13),
+          AdvanceManualClock(100),
+          CheckNewAnswer(13),
+          Execute { q =>
+            waitPendingOffsetWrites(q)
+            getListOfFiles(checkpointLocation + "/offsets")

Review Comment:
   sure



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048929851


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+    extends CommitLog(sparkSession, path) {
+

Review Comment:
   It is missing, but we have it in our internal implementation. Not sure how this was missed. Will add it back.





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1044068977


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())

Review Comment:
   This does not seem to be safe. Please see the code comment.
   
   ```
           // Now that we've updated the scheduler's persistent checkpoint, it is safe for the
           // sources to discard data from the previous batch.
           if (currentBatchId != 0) {
             val prevBatchOff = offsetLog.get(currentBatchId - 1)
             if (prevBatchOff.isDefined) {
               prevBatchOff.get.toStreamProgress(sources).foreach {
                 case (src: Source, off: Offset) => src.commit(off)
                 case (stream: MicroBatchStream, off) =>
                   stream.commit(stream.deserializeOffset(off.json))
                 case (src, _) =>
                   throw new IllegalArgumentException(
                     s"Unknown source is found at constructNextBatch: $src")
               }
             } else {
               throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist")
             }
           }
   ```
   
   This says that we never go back to reprocess batch (currentBatchId - 1), hence the data bound to batches lower than or equal to (currentBatchId - 1) can be cleaned up from the source (depending on the source implementation).
   
   Applying this here, we have to check the commit log to see which batch ID has been committed and will never be reprocessed in the future. The commit to the source should then use the offset log entry for that batch ID.
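
The rule described above can be illustrated with a minimal sketch. It assumes hypothetical in-memory stand-ins for the two logs (`OffsetLog` as a map from batch ID to a single `Long` offset, `CommitLog` as a set of committed batch IDs); these are not Spark's `OffsetSeqLog` or `CommitLog` classes, and `safeOffsetToCommit` is an illustrative name, not a method from the PR.

```scala
// Hypothetical in-memory stand-ins for the offset log and the commit log.
// These are NOT Spark's OffsetSeqLog / CommitLog classes.
object CommitToSourceSketch {
  // batchId -> end offset recorded for that batch (a single Long keeps the sketch small)
  type OffsetLog = Map[Long, Long]
  // ids of batches whose processing has been durably marked as complete
  type CommitLog = Set[Long]

  /**
   * Returns the offset that is safe to commit to the source, if any.
   * Only a batch present in the commit log is guaranteed never to be reprocessed,
   * so we look up the offset log entry for the latest such batch.
   */
  def safeOffsetToCommit(offsetLog: OffsetLog, commitLog: CommitLog): Option[Long] = {
    if (commitLog.isEmpty) None
    else offsetLog.get(commitLog.max)
  }

  def main(args: Array[String]): Unit = {
    val offsetLog: OffsetLog = Map(0L -> 100L, 1L -> 200L, 2L -> 300L)
    val commitLog: CommitLog = Set(0L, 1L) // batch 2 is planned but not yet committed
    println(safeOffsetToCommit(offsetLog, commitLog)) // prints Some(200)
  }
}
```

In this sketch, batch 2 has an offset log entry but no commit log entry, so its offset is not handed to the source; only the offset of batch 1, the latest committed batch, is.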



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on PR #38517:
URL: https://github.com/apache/spark/pull/38517#issuecomment-1336784998

   cc. @zsxwing @viirya @xuanyuanking in the hope of getting help with reviewing. I'll look into the PR soon as well.




[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048788954


##########
connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala:
##########
@@ -195,6 +200,102 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
     true
   }
 
+  /**
+   * Test async progress tracking capability with Kafka source and sink
+   */
+  test("async progress tracking") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 5)
+
+    val dataSent = new ListBuffer[String]()
+    testUtils.sendMessages(inputTopic, (0 until 15).map { case x =>
+      val m = s"foo-$x"
+      dataSent += m
+      m
+    }.toArray, Some(0))
+
+    val outputTopic = newTopic()
+    testUtils.createTopic(outputTopic, partitions = 5)
+
+    withTempDir { dir =>
+      val reader = spark
+        .readStream
+        .format("kafka")
+        .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+        .option("kafka.metadata.max.age.ms", "1")
+        .option("maxOffsetsPerTrigger", 5)
+        .option("subscribe", inputTopic)
+        .option("startingOffsets", "earliest")
+        .load()
+
+      def startQuery(): StreamingQuery = {
+        reader.writeStream
+          .format("kafka")
+          .option("checkpointLocation", dir.getCanonicalPath)
+          .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+          .option("kafka.max.block.ms", "5000")
+          .option("topic", outputTopic)
+          .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+          .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+          .queryName("kafkaStream")
+          .start()
+      }
+
+      def readResults(): ListBuffer[String] = {
+        val data = new ListBuffer[String]()
+        val readQuery = spark.readStream
+          .format("kafka")
+          .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+          .option("startingOffsets", "earliest")
+          .option("subscribe", outputTopic)
+          .load().writeStream

Review Comment:
   sure





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049249895


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been received exactly once and in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // the commit log should contain the same batches as the offset log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {

Review Comment:
   sure



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048934794


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"

Review Comment:
   ok





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048957081


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)

Review Comment:
   ok
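
For readers following the thread, the pattern in the quoted `markMicroBatchEnd` (an async write whose failure is recorded and later surfaced on the stream execution thread) can be sketched roughly as below. This is a minimal illustration using only `java.util.concurrent`; `ErrorNotifierSketch`, `firstError`, and `failingWrite` are hypothetical names introduced here and are not the PR's `errorNotifier` implementation.

```scala
import java.util.concurrent.{CompletableFuture, Executors}
import java.util.concurrent.atomic.AtomicReference

object AsyncErrorSketch {
  // Hypothetical notifier: keeps the first error reported by a background write.
  final class ErrorNotifierSketch {
    private val error = new AtomicReference[Throwable](null)
    def markError(th: Throwable): Unit = { error.compareAndSet(null, th); () }
    def firstError: Option[Throwable] = Option(error.get())
  }

  // Runs a write that always fails on a single background thread and records the failure.
  def failingWrite(notifier: ErrorNotifierSketch): Unit = {
    val executor = Executors.newSingleThreadExecutor()
    try {
      val write: Runnable = () => throw new IllegalStateException("simulated log write failure")
      val handler: java.util.function.Function[Throwable, Void] =
        (th: Throwable) => { notifier.markError(th); null }
      CompletableFuture
        .runAsync(write, executor)
        .exceptionally(handler)
        .join() // completes normally because the handler consumed the failure
    } finally {
      executor.shutdown()
    }
  }

  def main(args: Array[String]): Unit = {
    val notifier = new ErrorNotifierSketch
    failingWrite(notifier)
    // The calling thread can check the notifier between batches and fail the query if it is set.
    println(notifier.firstError.isDefined)
  }
}
```

Running `main` prints `true`: the failure raised on the background thread is visible to the calling thread through the notifier rather than being lost inside the future.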





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049011457


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -342,17 +342,14 @@ class MicroBatchExecution(
         isCurrentBatchConstructed = true
         availableOffsets = nextOffsets.toStreamProgress(sources)
         /* Initialize committed offsets to a committed batch, which at this
-         * is the second latest batch id in the offset log. */
-        if (latestBatchId != 0) {
-          val secondLatestOffsets = offsetLog.get(latestBatchId - 1).getOrElse {
-            logError(s"The offset log for batch ${latestBatchId - 1} doesn't exist, " +
-              s"which is required to restart the query from the latest batch $latestBatchId " +
-              "from the offset log. Please ensure there are two subsequent offset logs " +
-              "available for the latest batch via manually deleting the offset file(s). " +
-              "Please also ensure the latest batch for commit log is equal or one batch " +
-              "earlier than the latest batch for offset log.")
-            throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")
-          }
+         * is the second latest batch id in the offset log.

Review Comment:
   > Do we have a goal to support smooth transition between normal microbatch execution and async progress tracking for a single query?
   
   Yes.
   
   The existing behavior would break async progress tracking, especially if the user switches it on and off.
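
One way to picture the problem with the removed check: the quoted diff suggests that, with a checkpoint interval, not every batch is persisted to the offset log, so `latestBatchId - 1` need not exist on storage. The following is a minimal sketch under that assumption, modeling the offset log as a sorted set of persisted batch ids; `prevPersistedBatch` is a hypothetical helper and not the PR's `getPrevBatchFromStorage`.

```scala
import scala.collection.immutable.SortedSet

object PrevBatchSketch {
  // Hypothetical stand-in for an offset log: only the ids of batches persisted to storage.
  // Returns the largest persisted batch id strictly below `batchId`, if any.
  def prevPersistedBatch(persisted: SortedSet[Long], batchId: Long): Option[Long] =
    persisted.filter(_ < batchId).lastOption

  def main(args: Array[String]): Unit = {
    val withGaps = SortedSet(0L, 5L, 10L)
    // The old check would look for batch 9, which was never persisted.
    println(withGaps.contains(10L - 1))        // false
    // Looking up the previous persisted batch still finds a valid starting point.
    println(prevPersistedBatch(withGaps, 10L)) // Some(5)
  }
}
```

With a dense log such as `SortedSet(0L, 1L, 2L)` the helper returns `Some(1)` for batch 2, which matches what the original `latestBatchId - 1` lookup would have found.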





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049019969


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because commit log is on previously
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // every value should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // 10 more commit log entries should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // the new data is processed as batch 21, so one more offset log entry should be added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {

Review Comment:
   ok



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batch 2, 3, 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was on previously
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batch 2, 3, 4, 6, 7, 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // every value should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // 10 more commit log entries should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // the new data is processed as batch 21, so one more offset log entry should be added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {

Review Comment:
   ok
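
The directory checks in the diff above repeat the same list, filter, map and sort chain on the `offsets` and `commits` directories. A minimal sketch of a shared helper, assuming checkpoint entries are plain files named by batch ID. `CheckpointListing` and `batchIds` are hypothetical names and are not part of the PR. Unlike the chain in the diff, this version skips non-numeric file names instead of throwing.

```scala
import java.io.File
import java.nio.file.Files
import scala.util.Try

object CheckpointListing {
  // Hypothetical helper: returns the sorted batch IDs found in a log directory
  // such as <checkpoint>/offsets or <checkpoint>/commits.
  def batchIds(dir: File): Seq[Int] = {
    // listFiles returns null when dir is missing or is not a directory
    Option(dir.listFiles).map(_.toSeq).getOrElse(Seq.empty)
      .filter(f => f.isFile && !f.isHidden)
      // skip names that are not integers (for example checksum side files)
      .flatMap(f => Try(f.getName.toInt).toOption)
      .sorted
  }

  def main(args: Array[String]): Unit = {
    val dir = Files.createTempDirectory("offsets").toFile
    Seq("2", "0", "1", ".0.crc").foreach(n => new File(dir, n).createNewFile())
    println(batchIds(dir).mkString(",")) // prints 0,1,2
  }
}
```

With such a helper, each assertion in the suite could shrink to a single line comparing `batchIds(...)` against the expected range.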



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049024355


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was on previously
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // the new data should be processed as batch 21, adding one offset log entry
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during the async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during the async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {

Review Comment:
   sure
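
The "bubble up" tests quoted above rely on a failure in an asynchronous log write reaching the thread that drives the stream. A minimal sketch of that pattern in isolation, assuming a single-threaded executor and an `AtomicReference` that keeps the first failure. `AsyncLogWriter`, `submitWrite` and `rethrowIfFailed` are hypothetical names for illustration and are not taken from the PR.

```scala
import java.util.concurrent.{Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

// Hypothetical stand-in for an async log: writes run on a background thread,
// and the first failure is remembered so the caller can rethrow it later.
class AsyncLogWriter(write: Long => Unit) {
  private val executor = Executors.newSingleThreadExecutor()
  private val firstError = new AtomicReference[Throwable](null)

  def submitWrite(batchId: Long): Unit = {
    executor.execute(new Runnable {
      override def run(): Unit = {
        try {
          write(batchId)
        } catch {
          case t: Throwable =>
            // keep only the first failure; later ones are dropped
            firstError.compareAndSet(null, t)
            ()
        }
      }
    })
  }

  // Called from the driving thread to surface a background failure.
  def rethrowIfFailed(): Unit = {
    val t = firstError.get()
    if (t != null) throw new RuntimeException("async log write failed", t)
  }

  // Stops accepting writes and waits for queued writes to finish.
  def shutdown(): Unit = {
    executor.shutdown()
    executor.awaitTermination(5, TimeUnit.SECONDS)
  }
}

object AsyncLogWriterDemo {
  def main(args: Array[String]): Unit = {
    // Fault injection in the spirit of FailAsyncOffsetSeqLog: every write throws.
    val log = new AsyncLogWriter(_ => throw new Exception("test fail"))
    log.submitWrite(0L)
    log.shutdown()
    try {
      log.rethrowIfFailed()
    } catch {
      case e: RuntimeException => println(e.getCause.getMessage) // prints test fail
    }
  }
}
```

Wrapping the original exception as the cause keeps the background stack trace available, which is why a test in this style would assert on `getCause` rather than on the outer message.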





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049071577


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // Tests the basic functionality, i.e. the happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged with async progress tracking enabled
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // data should contain every value from 0 to 39, in order and without duplicates
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for batches 11 through 20 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by creating a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty

Review Comment:
   This is correct. If you look further down, we are deleting all offset log files except for batch 2, and all of the commit log files.
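
The state this comment describes (offset log holding only batch 2, commit log empty) can be checked with a small listing helper. The following is a minimal sketch, assuming a checkpoint on the local filesystem; `loggedBatchIds` and `CheckpointLogListing` are hypothetical names that are not part of the PR, and the helper simply wraps the `getListOfFiles(...).filter(...).map(...).sorted` chain that the suite repeats.

```scala
import java.io.File
import java.nio.file.Files

object CheckpointLogListing {
  // Hypothetical helper (not part of the PR): sorted batch ids of the visible
  // files in one checkpoint log directory, e.g. "<checkpoint>/offsets".
  def loggedBatchIds(dir: String): Seq[Int] = {
    val d = new File(dir)
    if (d.exists && d.isDirectory) {
      d.listFiles.toSeq
        .filter(f => f.isFile && !f.isHidden)
        .map(_.getName.toInt)
        .sorted
    } else {
      Seq.empty[Int]
    }
  }

  def main(args: Array[String]): Unit = {
    // Stand-in for a checkpoint directory; real log files are written by Spark.
    val root = Files.createTempDirectory("checkpoint-sketch").toFile
    val offsets = new File(root, "offsets")
    val commits = new File(root, "commits")
    offsets.mkdirs()
    commits.mkdirs()
    for (id <- 0 to 2) {
      new File(offsets, id.toString).createNewFile()
      new File(commits, id.toString).createNewFile()
    }

    // Mimic the setup described above: keep only offset file 2, drop all commits.
    offsets.listFiles.filter(_.getName != "2").foreach(_.delete())
    commits.listFiles.foreach(_.delete())

    assert(loggedBatchIds(offsets.getPath) == Seq(2))
    assert(loggedBatchIds(commits.getPath).isEmpty)
  }
}
```

A helper of this shape would also shorten the assertions in the suite, at the cost of one more utility method to maintain.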



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049018591


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 20 in the commit log is written to durable storage

Review Comment:
   The offsets are checked right after this. This check is mainly a way to wait for the commit log write to complete.
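
The "wait, then assert" pattern mentioned here can be sketched without any test framework. This is a minimal illustration, assuming a simple polling loop; `waitUntil` and `WaitSketch` are hypothetical names, and the suite itself relies on ScalaTest's `eventually(timeout(...))` rather than this helper. The background thread only simulates an asynchronous commit log write.

```scala
import java.util.concurrent.atomic.AtomicInteger

object WaitSketch {
  // Hypothetical helper: poll `condition` until it holds or `timeoutMs` elapses.
  // Returns the last observed value of the condition.
  def waitUntil(timeoutMs: Long, pollMs: Long = 10L)(condition: => Boolean): Boolean = {
    val deadline = System.nanoTime() + timeoutMs * 1000000L
    var ok = condition
    while (!ok && System.nanoTime() < deadline) {
      Thread.sleep(pollMs)
      ok = condition
    }
    ok
  }

  def main(args: Array[String]): Unit = {
    // Simulated state: the id of the last batch whose commit reached storage.
    val lastCommittedBatch = new AtomicInteger(19)

    // Simulated asynchronous writer that finishes batch 20 a little later.
    val writer = new Thread(new Runnable {
      override def run(): Unit = {
        Thread.sleep(50)
        lastCommittedBatch.set(20)
      }
    })
    writer.start()

    // Block until the write is visible, then run the exact assertions.
    val done = waitUntil(timeoutMs = 5000)(lastCommittedBatch.get() == 20)
    writer.join()
    assert(done, "batch 20 was not committed within the timeout")
  }
}
```

Waiting on the last expected batch id first keeps the exact directory-content assertions that follow from racing with the asynchronous writer.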





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049015251


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet

Review Comment:
   ok



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {

Review Comment:
   ok



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049011990


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -342,17 +342,14 @@ class MicroBatchExecution(
         isCurrentBatchConstructed = true
         availableOffsets = nextOffsets.toStreamProgress(sources)
         /* Initialize committed offsets to a committed batch, which at this
-         * is the second latest batch id in the offset log. */
-        if (latestBatchId != 0) {
-          val secondLatestOffsets = offsetLog.get(latestBatchId - 1).getOrElse {
-            logError(s"The offset log for batch ${latestBatchId - 1} doesn't exist, " +
-              s"which is required to restart the query from the latest batch $latestBatchId " +
-              "from the offset log. Please ensure there are two subsequent offset logs " +
-              "available for the latest batch via manually deleting the offset file(s). " +
-              "Please also ensure the latest batch for commit log is equal or one batch " +
-              "earlier than the latest batch for offset log.")
-            throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")
-          }
+         * is the second latest batch id in the offset log.
+         * The offset log may not be contiguous */
+        val prevBatchId = offsetLog.getPrevBatchFromStorage(latestBatchId)
+        if (latestBatchId != 0 && prevBatchId.isDefined) {
+            val secondLatestOffsets = offsetLog.get(prevBatchId.get).getOrElse({

Review Comment:
   will fix





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052636003


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -342,17 +342,14 @@ class MicroBatchExecution(
         isCurrentBatchConstructed = true
         availableOffsets = nextOffsets.toStreamProgress(sources)
         /* Initialize committed offsets to a committed batch, which at this
-         * is the second latest batch id in the offset log. */
-        if (latestBatchId != 0) {
-          val secondLatestOffsets = offsetLog.get(latestBatchId - 1).getOrElse {
-            logError(s"The offset log for batch ${latestBatchId - 1} doesn't exist, " +
-              s"which is required to restart the query from the latest batch $latestBatchId " +
-              "from the offset log. Please ensure there are two subsequent offset logs " +
-              "available for the latest batch via manually deleting the offset file(s). " +
-              "Please also ensure the latest batch for commit log is equal or one batch " +
-              "earlier than the latest batch for offset log.")
-            throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")
-          }
+         * is the second latest batch id in the offset log.

Review Comment:
   This logic does not guarantee exactly-once behavior. It is merely a sanity check to guard against bugs, so removing it does not break exactly-once behavior.
   
   
   > Wouldn't that be a more serious problem than supporting the switch? If we really want to support switching, can we support it only when the checkpoint interval is disabled, so that we don't change normal micro-batch execution in a way that could break fault-tolerance semantics?
   
   I don't quite follow. How would this work? The framework would somehow have to remember the settings of the previous run. We would need to add metadata to the offsets to record which of them were written while async progress tracking was in use. Spark has no such functionality today, and I am not convinced it is worthwhile to implement it for this use case.
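   
   To make the "offset log may not be contiguous" point concrete, here is a minimal sketch of how a previous batch id could be located when some batch ids were never persisted. The object and method names are hypothetical and chosen only for illustration; this is not the PR's `getPrevBatchFromStorage` implementation, and it works on an in-memory collection of ids rather than on files in the checkpoint directory.
   
   ```scala
   // Hypothetical helper for illustration only; not the PR's actual code.
   object PrevBatchSketch {
     /**
      * Returns the largest persisted batch id strictly below `latest`,
      * or None when no earlier batch was persisted.
      */
     def prevBatchId(persisted: Iterable[Long], latest: Long): Option[Long] = {
       val earlier = persisted.filter(_ < latest)
       if (earlier.isEmpty) None else Some(earlier.max)
     }
   }
   ```
   
   With persisted ids 0, 3 and 7 and a latest batch of 7, the helper yields `Some(3)`, whereas a lookup of `latestBatchId - 1` would ask for batch 6, which was never written.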





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052963923


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -275,7 +298,7 @@ object AsyncProgressTrackingMicroBatchExecution {
     "_asyncProgressTrackingOverrideSinkSupportCheck"
 
   private def getAsyncProgressTrackingCheckpointingIntervalMs(
-                                                               extraOptions: Map[String, String]): Long = {
+    extraOptions: Map[String, String]): Long = {

Review Comment:
   will fix





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052534190


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncOffsetSeqLog.scala:
##########
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.{Clock, SystemClock}
+
+/**
+ * Used to write entries to the offset log asynchronously
+ */
+class AsyncOffsetSeqLog(
+  sparkSession: SparkSession,

Review Comment:
   will fix





[GitHub] [spark] AmplabJenkins commented on pull request #38517: [WIP][SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
AmplabJenkins commented on PR #38517:
URL: https://github.com/apache/spark/pull/38517#issuecomment-1304497085

   Can one of the admins verify this patch?




[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1043434280


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"
+          )
+        }
+      }
+      offsetLog.removeAsyncOffsetWrite(currentBatchId)
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  // need to look at the number of files on disk
+  override def purge(threshold: Long): Unit = {
+    while (offsetLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      offsetLog.writtenToDurableStorage.poll()
+    }
+    offsetLog.purge(offsetLog.writtenToDurableStorage.peek())
+
+    while (commitLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      commitLog.writtenToDurableStorage.poll()
+    }
+    commitLog.purge(commitLog.writtenToDurableStorage.peek())
+  }
+
+  override def cleanup(): Unit = {
+    super.cleanup()
+
+    ThreadUtils.shutdown(asyncWritesExecutorService)
+    logInfo(s"Async progress tracking executor pool for query ${prettyIdString} has been shutdown")
+  }
+
+  // used for testing
+  def areWritesPendingOrInProgress(): Boolean = {
+    asyncWritesExecutorService.getQueue.size() > 0 || asyncWritesExecutorService.getActiveCount > 0
+  }
+
+  private def validateAndGetTrigger(): TriggerExecutor = {
+    // validate that the pipeline is using a supported sink
+    if (!extraOptions
+      .get(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK

Review Comment:
   OK, reading back through the code, I see that the option is explicitly marked as test-only.





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1043380933


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful

Review Comment:
   There seems to be a terminology issue.
   
   https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html
   
   The guide never mentions "pipeline". The term it uses is "query", or more specifically "streaming query".
   
   While you are changing this, please also rename the method below.





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049028252


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been received, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1;
+        // this should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(

Review Comment:
   This is related to async progress tracking change
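
   The "interval commits and recovery" test quoted above expects only batches 0 and 3 (and, after the restart, 4 and 7) to reach the offset and commit logs. A minimal sketch of the timing rule that would produce those results follows. It is a simplified model, not Spark's implementation. It assumes that the first batch of every run is always persisted and that a later batch is persisted once at least the checkpoint interval has elapsed since the last persisted one. `IntervalCommitSketch` and `persistedBatches` are hypothetical names used only for illustration.

```scala
object IntervalCommitSketch {

  /**
   * Simulates one run of a query and returns the batch ids that would be
   * persisted, together with the final clock reading.
   *
   * @param intervalMs   checkpoint interval in milliseconds
   * @param startClockMs manual clock reading when the run starts
   * @param firstBatchId id of the first batch executed in this run
   * @param advancesMs   how far the manual clock advances before each batch
   */
  def persistedBatches(
      intervalMs: Long,
      startClockMs: Long,
      firstBatchId: Int,
      advancesMs: Seq[Long]): (Seq[Int], Long) = {
    var now = startClockMs
    // None until something has been persisted in this run (assumed rule:
    // the first batch after a start or restart is always written).
    var lastPersistedAt: Option[Long] = None
    val persisted = Seq.newBuilder[Int]

    advancesMs.zipWithIndex.foreach { case (advance, i) =>
      now += advance
      val due = lastPersistedAt.forall(t => now - t >= intervalMs)
      if (due) {
        persisted += firstBatchId + i
        lastPersistedAt = Some(now)
      }
    }
    (persisted.result(), now)
  }

  def main(args: Array[String]): Unit = {
    val advances = Seq(100L, 100L, 100L, 800L)
    // First run: batches 0 to 3, clock starts at 0.
    val (firstRun, clockAfter) = persistedBatches(1000L, 0L, 0, advances)
    // Second run: batches 4 to 7, reusing the same manual clock.
    val (secondRun, _) = persistedBatches(1000L, clockAfter, 4, advances)
    println(firstRun ++ secondRun)
  }
}
```

   Under these assumptions the first run persists batches 0 and 3 and the second run persists batches 4 and 7, which matches the `Array(0, 3, 4, 7)` assertions in the test.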





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049019214


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written by the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure that batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been received exactly once and in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure that batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(

Review Comment:
   will fix
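
   The quoted suite repeats the same `getListOfFiles(...).filter(...).map(...).sorted` chain for both the `offsets` and `commits` directories. A minimal sketch of a single helper that captures the pattern follows. `CheckpointListingSketch` and `loggedBatchIds` are hypothetical names that are not part of the PR, and the sketch assumes, as the suite does, that every non-hidden file in a log directory is named after its batch id.

```scala
import java.io.File

object CheckpointListingSketch {

  /**
   * Returns the sorted batch ids found in `<checkpointLocation>/<logName>`,
   * for example logName = "offsets" or "commits".
   * A missing directory yields an empty list.
   */
  def loggedBatchIds(checkpointLocation: String, logName: String): List[Int] = {
    val dir = new File(checkpointLocation, logName)
    // listFiles returns null when the path is not a directory.
    val files = Option(dir.listFiles).map(_.toList).getOrElse(Nil)
    files
      .filter(f => f.isFile && !f.isHidden) // skip hidden files such as checksum files
      .map(_.getName.toInt)                 // file name is assumed to be the batch id
      .sorted
  }
}
```

   With such a helper, an assertion in the suite could read `loggedBatchIds(checkpointLocation, "offsets") should equal(List(0, 3))`. Note that `File.isHidden` is platform dependent: on Unix-like systems it is true for names that start with a dot.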





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049017655


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because commit log is on previously
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 0 through 6 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(

Review Comment:
   There are subtle differences between the methods that share this name. I think it's more readable to keep each one in the context of the test that uses it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049014916


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))

Review Comment:
   ok





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049014598


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {

Review Comment:
   The test should time out, right?
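
   For reference, a minimal sketch of how the timed variants of the JDK primitives used in this test behave when nothing signals them. `LatchTimeoutSketch` and `awaitOrReport` are hypothetical names introduced only for illustration; they are not part of the PR. The point is that `CountDownLatch.await(timeout, unit)` and `Semaphore.tryAcquire(timeout, unit)` return `false` on timeout rather than throwing, so a caller has to check the result if it wants to fail fast.

   ```scala
   import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}

   // Hypothetical helper object, not part of the PR under review.
   object LatchTimeoutSketch {
     // Returns true only if the latch reached zero before the timeout elapsed.
     def awaitOrReport(latch: CountDownLatch, timeoutMs: Long): Boolean =
       latch.await(timeoutMs, TimeUnit.MILLISECONDS)

     def main(args: Array[String]): Unit = {
       // Nothing ever calls countDown(), so await gives up and returns false.
       val neverCounted = new CountDownLatch(1)
       assert(!awaitOrReport(neverCounted, 50L))

       // A latch that has already reached zero returns true immediately.
       val done = new CountDownLatch(1)
       done.countDown()
       assert(awaitOrReport(done, 50L))

       // tryAcquire with a timeout returns false instead of blocking forever,
       // unlike the untimed acquire().
       val sem = new Semaphore(0)
       assert(!sem.tryAcquire(50L, TimeUnit.MILLISECONDS))
     }
   }
   ```

   In a test, the Boolean result could be wrapped in an `assert` so that a timeout is reported where it happens instead of surfacing later as a mismatched count.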





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049013658


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)

Review Comment:
   There are no default values for the parameters of MemoryStream.
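
   A small sketch of the general Scala behaviour behind this, using a hypothetical `SketchStream` class as a stand-in (it is not Spark's `MemoryStream`, and its `label` parameter is invented for the example). When a constructor declares no default values, every argument must be supplied at the call site; naming the arguments, as the test does with `id = 0, sqlContext = sqlContext`, only documents which value is which.

   ```scala
   // Hypothetical stand-in class; both parameters are required because
   // neither declares a default value.
   final case class SketchStream[A](id: Int, label: String)

   object NamedArgsSketch {
     def main(args: Array[String]): Unit = {
       // Every argument has to be passed; the names make the call self-describing.
       val s = SketchStream[Int](id = 0, label = "test-input")
       assert(s.id == 0)

       // Named arguments may be written in any order and build an equal value.
       val t = SketchStream[Int](label = "test-input", id = 0)
       assert(s == t)
     }
   }
   ```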





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049012589


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest

Review Comment:
   will fix





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049070974


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been received, in order and without duplicates
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for batches 11 through 20 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occurred during async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by creating a file for the next offset, i.e. 1;
+        // this should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occurred during async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occurred during async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/offsets")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occurred during async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " not commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {

Review Comment:
   It's testing the new semantics that async progress tracking introduced.
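
As a rough illustration of those semantics (a sketch under stated assumptions, not the PR's implementation): with a non-zero checkpointing interval, only some batches seem to be persisted to the offset and commit logs, which leaves gaps between batch ids. The hypothetical `persistedBatches` helper below models this with a simple assumed rule: the first batch is always persisted, and a later batch is persisted once at least `intervalMs` has elapsed since the last persisted write. The object name, helper name, and rule are all assumptions made for illustration.

```scala
object IntervalCheckpointSketch {
  // Hypothetical model, not Spark code: given the time (in ms) at which each
  // batch completes, return the ids of the batches that would be persisted.
  def persistedBatches(completionTimesMs: Seq[Long], intervalMs: Long): Seq[Int] = {
    var lastWriteMs: Option[Long] = None
    val persisted = Seq.newBuilder[Int]
    completionTimesMs.zipWithIndex.foreach { case (timeMs, batchId) =>
      // Assumed rule: persist the first batch, then persist again only after
      // at least intervalMs has passed since the last persisted write.
      val due = lastWriteMs.forall(last => timeMs - last >= intervalMs)
      if (due) {
        persisted += batchId
        lastWriteMs = Some(timeMs)
      }
    }
    persisted.result()
  }
}
```

Under this model, `IntervalCheckpointSketch.persistedBatches(Seq(100L, 200L, 300L, 1100L), 1000L)` returns batches 0 and 3, which lines up with the files the "interval commits and recovery" test expects before the restart. An interval of 0 persists every batch.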



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org




[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049079011


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // data should contain every value exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries should be logged for the same batches
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by pre-creating the file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are being executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // write batches 0 through 5, then delete log files below to create multiple gaps
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // write batches 0 through 2, then delete log files below to create gaps
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),

Review Comment:
   I don't quite follow: one batch would cover 1 -> 3 and another would cover 3 -> 4.
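
To make the ranges in this comment concrete, here is a minimal toy model, not Spark's implementation. It assumes each offset log entry stores the exclusive end position of its batch, and that a replayed batch starts where the nearest earlier offset log entry ends. `ReplayRangeSketch`, `ReplaySpan`, and `replayRange` are hypothetical names used only for illustration.

```scala
// Toy model (assumption): offsetLog maps batch id -> exclusive end position,
// commitLog holds the ids of batches known to have completed.
object ReplayRangeSketch {
  final case class ReplaySpan(batchId: Long, from: Int, until: Int)

  // Returns the span of the batch to replay on restart, if any.
  def replayRange(offsetLog: Map[Long, Int], commitLog: Set[Long]): Option[ReplaySpan] = {
    if (offsetLog.isEmpty) {
      None
    } else {
      val latest = offsetLog.keys.max
      if (commitLog.contains(latest)) {
        None // the latest planned batch was committed, nothing to replay
      } else {
        // start where the nearest earlier logged batch ended, or at 0 if none exists
        val earlier = offsetLog.keys.filter(_ < latest)
        val from = if (earlier.isEmpty) 0 else offsetLog(earlier.max)
        Some(ReplaySpan(latest, from, offsetLog(latest)))
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // offsets/0 ends after element 0, offsets/3 ends after element 3; only batch 0 is committed
    println(replayRange(Map(0L -> 1, 3L -> 4), Set(0L)))
  }
}
```

Under these assumptions, the state in this test (offset log entries 0 and 3, commit log entry 0) yields a single replayed batch 3 covering elements 1 through 3, and the newly added element 4 would then form its own batch.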





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052636553


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -342,17 +342,14 @@ class MicroBatchExecution(
         isCurrentBatchConstructed = true
         availableOffsets = nextOffsets.toStreamProgress(sources)
         /* Initialize committed offsets to a committed batch, which at this
-         * is the second latest batch id in the offset log. */
-        if (latestBatchId != 0) {
-          val secondLatestOffsets = offsetLog.get(latestBatchId - 1).getOrElse {
-            logError(s"The offset log for batch ${latestBatchId - 1} doesn't exist, " +
-              s"which is required to restart the query from the latest batch $latestBatchId " +
-              "from the offset log. Please ensure there are two subsequent offset logs " +
-              "available for the latest batch via manually deleting the offset file(s). " +
-              "Please also ensure the latest batch for commit log is equal or one batch " +
-              "earlier than the latest batch for offset log.")
-            throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")
-          }
+         * is the second latest batch id in the offset log.

Review Comment:
   How often do we actually catch issues with this check?
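
For context, the difference between the removed check and a gap-tolerant lookup can be sketched as below. This is a simplified model under stated assumptions, not the code in this PR: the offset log is represented as a plain `Map`, and `SecondLatestOffsetSketch`, `strictPrevious`, and `nearestPrevious` are hypothetical names.

```scala
object SecondLatestOffsetSketch {
  // Strict variant, modeled on the removed lines: batch (latestBatchId - 1) must exist.
  def strictPrevious(offsetLog: Map[Long, String], latestBatchId: Long): Option[String] =
    if (latestBatchId == 0) {
      None
    } else {
      Some(offsetLog.getOrElse(
        latestBatchId - 1,
        throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")))
    }

  // Relaxed variant (assumption): accept the nearest earlier entry, tolerating gaps.
  def nearestPrevious(offsetLog: Map[Long, String], latestBatchId: Long): Option[String] = {
    val earlier = offsetLog.keys.filter(_ < latestBatchId)
    if (earlier.isEmpty) None else Some(offsetLog(earlier.max))
  }
}
```

With an offset log containing only batches 0 and 3, `strictPrevious(log, 3)` throws because batch 2 is missing, while `nearestPrevious(log, 3)` returns the entry for batch 0. The strict form guards against accidentally deleted offset files, which is the trade-off this question is probing.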





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052663782


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged because the checkpointing interval is 0
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure that batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure that batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // every value should appear exactly once, in insertion order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure that batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the new batches should also be logged because the interval is 0
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure that batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error by writing a file for the next offset i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),

Review Comment:
   Yes, it's the former, i.e. Spark will run a single microbatch as batch 2, which also includes the previously available data.
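
The recovery behavior described here can be sketched as a simplified model. This is an illustration under stated assumptions, not Spark's actual implementation: `RestartPlan`, `RecoverySketch`, and `planRestart` are hypothetical names, and offsets are modelled as plain integers rather than per-source `OffsetSeq` entries.

```scala
// Hypothetical model: the batch to re-run and the offset range it covers.
// startExclusive = None means "replay from the beginning of the source".
final case class RestartPlan(batchId: Long, startExclusive: Option[Int], endInclusive: Int)

object RecoverySketch {
  // offsetLog: batch id -> end offset recorded for that batch (assumed Int here)
  // commitLog: ids of batches that were recorded as completed
  def planRestart(offsetLog: Map[Long, Int], commitLog: Set[Long]): Option[RestartPlan] = {
    if (offsetLog.isEmpty) {
      None // fresh query, nothing to recover
    } else {
      val latest = offsetLog.keys.max
      if (commitLog.contains(latest)) {
        None // latest batch already committed, nothing to replay
      } else {
        // Start after the most recent earlier batch that still has an offset log entry.
        val earlier = offsetLog.keys.filter(_ < latest)
        val start = if (earlier.isEmpty) None else Some(offsetLog(earlier.max))
        Some(RestartPlan(latest, start, offsetLog(latest)))
      }
    }
  }
}
```

With an offset log that only contains batch 2 and an empty commit log, `RecoverySketch.planRestart(Map(2L -> 2), Set.empty)` yields a plan for batch 2 with no start offset, so the single re-run batch covers all previously available data.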



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org




[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052722369


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -342,17 +342,14 @@ class MicroBatchExecution(
         isCurrentBatchConstructed = true
         availableOffsets = nextOffsets.toStreamProgress(sources)
         /* Initialize committed offsets to a committed batch, which at this
-         * is the second latest batch id in the offset log. */
-        if (latestBatchId != 0) {
-          val secondLatestOffsets = offsetLog.get(latestBatchId - 1).getOrElse {
-            logError(s"The offset log for batch ${latestBatchId - 1} doesn't exist, " +
-              s"which is required to restart the query from the latest batch $latestBatchId " +
-              "from the offset log. Please ensure there are two subsequent offset logs " +
-              "available for the latest batch via manually deleting the offset file(s). " +
-              "Please also ensure the latest batch for commit log is equal or one batch " +
-              "earlier than the latest batch for offset log.")
-            throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")
-          }
+         * is the second latest batch id in the offset log.

Review Comment:
   How often we encounter this is not important. What matters is that we are removing a guard that prevents breaking the fault-tolerance semantics.





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048933229


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncOffsetSeqLog.scala:
##########
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.{Clock, SystemClock}
+
+/**
+ * Used to write entries to the offset log asynchronously
+ */
+class AsyncOffsetSeqLog(
+    sparkSession: SparkSession,
+    path: String,
+    executorService: ThreadPoolExecutor,
+    offsetCommitIntervalMs: Long,
+    clock: Clock = new SystemClock())
+    extends OffsetSeqLog(sparkSession, path) {
+
+  // the cache needs to be enabled because we may not be persisting every entry to durable storage
+  // entries not persisted to durable storage will just be stored in memory for faster lookups
+  assert(metadataCacheEnabled == true)
+
+  // A map of the current pending offset writes. Key -> batch Id, Value -> CompletableFuture
+  // Used to determine if a commit log entry for this batch also needs to be persisted to storage
+  private val pendingOffsetWrites = new ConcurrentHashMap[Long, CompletableFuture[Long]]()
+
+  // Keeps track of the last time a commit was issued. Used for issuing commits to storage at
+  // the configured intervals
+  private val lastCommitIssuedTimestampMs: AtomicLong = new AtomicLong(-1)
+
+  // A queue of batches written to storage.  Used to keep track when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Get an async offset write by batch id.  Used to check if a corresponding commit log entry
+   * needs to be written to durable storage as well
+   * @param batchId id of the batch to look up
+   * @return an option indicating whether an async offset write was issued for the given batch id
+   */
+  def getAsyncOffsetWrite(batchId: Long): Option[CompletableFuture[Long]] = {
+    Option(pendingOffsetWrites.get(batchId))
+  }
+
+  /**
+   * Remove the async offset write when we don't need to keep track of it anymore
+   * @param batchId
+   */
+  def removeAsyncOffsetWrite(batchId: Long): Unit = {
+    pendingOffsetWrites.remove(batchId)
+  }
+
+  /**
+   * Writes a new batch to the offset log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompletableFuture containing the batch id and whether the offset was persisted.
+   *         The future is completed when the async write of the batch is completed.  It may
+   *         also be completed exceptionally to indicate a write error.
+   */
+  def addAsync(batchId: Long, metadata: OffsetSeq): CompletableFuture[(Long, Boolean)] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+
+    def issueAsyncWrite(batchId: Long): CompletableFuture[Long] = {
+      lastCommitIssuedTimestampMs.set(clock.getTimeMillis())
+      val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+        serialize(metadata, output)
+      }.thenApply((ret: Boolean) => {
+        if (ret) {
+          batchId
+        } else {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"
+          )
+        }
+      })
+      pendingOffsetWrites.put(batchId, future)
+      future
+    }
+
+    val lastIssuedTs = lastCommitIssuedTimestampMs.get()
+    val future: CompletableFuture[(Long, Boolean)] = {
+      if (offsetCommitIntervalMs > 0) {
+        if ((lastIssuedTs == -1) // haven't started any commits yet
+          || (lastIssuedTs + offsetCommitIntervalMs) <= clock.getTimeMillis()) {
+          issueAsyncWrite(batchId).thenApply((batchId: Long) => {
+            (batchId, true)
+          })
+        } else {
+          // just return completed future because we are not persisting this offset
+          CompletableFuture.completedFuture((batchId, false))
+        }
+      } else {
+        // offset commit interval is not enabled
+        issueAsyncWrite(batchId).thenApply((batchId: Long) => {
+          (batchId, true)
+        })
+      }
+    }
+
+    batchCache.put(batchId, metadata)
+    future
+  }
+
+  /**
+   * Adds new batch asynchronously
+   * @param batchId id of batch to write
+   * @param fn serialization function
+   * @return CompletableFuture that contains a boolean to
+   *         indicate whether the write was successful or not.
+   *         Future can also be completed exceptionally to indicate write errors.
+   */
+  private def addNewBatchByStreamAsync(batchId: Long)(
+      fn: OutputStream => Unit): CompletableFuture[Boolean] = {
+    val future = new CompletableFuture[Boolean]()
+    val batchMetadataFile = batchIdToPath(batchId)
+
+    if (batchCache.containsKey(batchId)) {
+      future.complete(false)
+      future
+    } else {
+      executorService.submit(new Runnable {
+        override def run(): Unit = {
+          try {
+            if (fileManager.exists(batchMetadataFile)) {
+              future.complete(false)
+            } else {
+              val start = System.currentTimeMillis()
+              write(
+                batchMetadataFile,
+                fn
+              )
+              logDebug(
+                s"Offset commit for batch ${batchId} took" +
+                s" ${System.currentTimeMillis() - start} ms to be persisted to durable storage"
+              )
+              writtenToDurableStorage.add(batchId)
+              future.complete(true)
+            }
+          } catch {
+            case e: Throwable =>
+              logError(s"Encountered error while writing batch ${batchId} to offset log", e)
+              future.completeExceptionally(e)
+          }
+        }
+      })
+      future
+    }
+  }
+
+  /**
+   * Purge entries in the offset log up to thresholdBatchId.  This method is synchronized so that

Review Comment:
   This comment is out of date. We don't need synchronization: purges always delete old entries and writes only append new entries, so they should not conflict.
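
   A minimal sketch of why the two operations work on opposite ends of the structure, assuming batch ids are appended in increasing order and there is a single purging thread. `PurgeSketch`, `recordWrite` and `purgeBelow` are hypothetical names, not the API in this PR.

```scala
import java.util.concurrent.ConcurrentLinkedDeque

// Hypothetical stand-in for the queue of batch ids written to durable storage.
final class PurgeSketch {
  // Oldest batch id at the head, newest at the tail.
  val written = new ConcurrentLinkedDeque[java.lang.Long]()

  // Writer side: only ever appends a new batch id at the tail.
  def recordWrite(batchId: Long): Unit = {
    written.addLast(java.lang.Long.valueOf(batchId))
  }

  // Purge side: only ever removes batch ids older than the threshold from the head.
  def purgeBelow(thresholdBatchId: Long): Unit = {
    var head = written.peekFirst()
    while (head != null && head.longValue() < thresholdBatchId) {
      written.pollFirst()
      head = written.peekFirst()
    }
  }
}
```

   Because the deque is a concurrent collection and each side touches a different end, no extra lock is needed under these assumptions.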



Review Comment:
   I will edit the comment.





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048937234


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+

Review Comment:
   sure





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052963344


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -1762,4 +1344,173 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
       }
     }
   }
+
+  test("test gaps in offset log") {
+    val inputData = MemoryStream[Int]
+    val streamEvent = inputData.toDF().select("value")
+
+    val resourceUri = this.getClass.getResource(
+      "/structured-streaming/checkpoint-test-offsetId-commitId-inconsistent/").toURI
+    val checkpointDir = Utils.createTempDir().getCanonicalFile
+    // Copy the checkpoint to a temp dir to prevent changes to the original.
+    // Not doing this will lead to the test passing on the first run, but fail subsequent runs.
+    FileUtils.copyDirectory(new File(resourceUri), checkpointDir)
+
+    // Not doing this will lead to the test passing on the first run, but fail subsequent runs.
+    FileUtils.copyDirectory(new File(resourceUri), checkpointDir)
+
+    testStream(streamEvent, extraOptions = Map(
+      ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+      ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+    ))(
+      AddData(inputData, 0),
+      AddData(inputData, 1),
+      AddData(inputData, 2),
+      AddData(inputData, 3),
+      AddData(inputData, 4),
+      StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+      CheckAnswer(3, 4)
+    )
+
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only
+    // contains batches 0, 2, 5 and the commit log only contains 0, 2
+    testStream(ds, extraOptions = Map(
+      ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+      ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+    ))(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    // delete all offset files except for batch 0, 2, 5
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filterNot(f => f.getName.startsWith("0")
+        || f.getName.startsWith("2")
+        || f.getName.startsWith("5"))
+      .foreach(_.delete())
+
+    // delete all commit log files except for batch 0, 2
+    getListOfFiles(checkpointLocation + "/commits")
+      .filterNot(f => f.getName.startsWith("0") || f.getName.startsWith("2"))
+      .foreach(_.delete())
+
+    getBatchIdsSortedFromLog(checkpointLocation + "/offsets") should equal(Array(0, 2, 5))
+    getBatchIdsSortedFromLog(checkpointLocation + "/commits") should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2, extraOptions = Map(
+      ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+      ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+    ))(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // The offset log contains batches 0, 2, 5 and the commit log contains batches 0, 2,
+      // which indicates that we have successfully processed up to batch 2.
+      // Thus the data we need to process / re-process is batches 3, 4, 5
+      CheckNewAnswer(3, 4, 5),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        eventually(timeout(Span(5, Seconds))) {
+          getBatchIdsSortedFromLog(checkpointLocation + "/offsets") should equal(Array(0, 2, 5))
+          getBatchIdsSortedFromLog(checkpointLocation + "/commits") should equal(Array(0, 2, 5))
+        }
+      },
+      StopStream
+    )
+
+    getBatchIdsSortedFromLog(checkpointLocation + "/offsets") should equal(Array(0, 2, 5))
+    getBatchIdsSortedFromLog(checkpointLocation + "/commits") should equal(Array(0, 2, 5))
+  }
+
+  test("recovery when gaps exist in offset and commit log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only
+    // contains batch 0, 2 and commit log only contains 9

Review Comment:
   will fix





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052931956


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -275,7 +298,7 @@ object AsyncProgressTrackingMicroBatchExecution {
     "_asyncProgressTrackingOverrideSinkSupportCheck"
 
   private def getAsyncProgressTrackingCheckpointingIntervalMs(
-                                                               extraOptions: Map[String, String]): Long = {
+    extraOptions: Map[String, String]): Long = {

Review Comment:
   nit: 4 spaces





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1050436774


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+  sparkSession: SparkSession,

Review Comment:
   nit: same - 4 spaces for params, 2 spaces for extends.
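
   For reference, the layout being asked for looks roughly like this. `Base` and `Example` are hypothetical classes, used only to show the indentation.

```scala
// Hypothetical classes, only to illustrate the indentation convention.
class Base(val name: String, val intervalMs: Long)

class Example(
    name: String, // constructor parameters: 4 spaces
    intervalMs: Long)
  extends Base(name, intervalMs) { // extends clause: 2 spaces

  def describe: String = s"$name every $intervalMs ms"
}
```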



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+  sparkSession: SparkSession,
+  trigger: Trigger,
+  triggerClock: Clock,
+  extraOptions: Map[String, String],
+  plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  import AsyncProgressTrackingMicroBatchExecution._
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  // used to check during the first batch if the pipeline is stateful
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if streamign query is stateful

Review Comment:
   nit: typo, `streamign` should be `streaming`



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1765 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.File
+import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.TestUtils
+import org.apache.spark.sql._
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.Utils
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+  extends StreamTest
+    with BeforeAndAfter

Review Comment:
   nit: `with` should be at the same indentation as `extends`
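
   A minimal sketch of the layout this nit appears to ask for, assuming the usual Spark Scala style where `extends` and each `with` share a 2-space indent under the class name. The trait and class names below are hypothetical stand-ins, not the real `StreamTest`, `BeforeAndAfter`, or `Matchers`:

```scala
// Hypothetical placeholder traits so the snippet compiles on its own;
// they stand in for StreamTest, BeforeAndAfter and Matchers.
trait StreamTestLike
trait BeforeAndAfterLike
trait MatchersLike

// `extends` and every `with` are indented by two spaces,
// so all parent types line up in one column.
class ExampleSuite
  extends StreamTestLike
  with BeforeAndAfterLike
  with MatchersLike {

  // Number of parent traits mixed in above (illustrative only).
  def mixinCount: Int = 3
}
```

   Only the whitespace in front of the `with` lines changes; the class hierarchy itself stays the same.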



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1765 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.File
+import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.TestUtils
+import org.apache.spark.sql._
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.Utils
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+  extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written by the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 21 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been received exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for batches 11 through 20 should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 21 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  def testAsyncWriteErrorsAlreadyExists(path: String): Unit = {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + path))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            TestUtils.assertExceptionMsg(e,
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 1:" +
+    " offset file already exists for a batch") {
+    testAsyncWriteErrorsAlreadyExists("/offsets/1")
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1:" +
+    " commit file already exists for a batch") {
+    testAsyncWriteErrorsAlreadyExists("/commits/1")
+  }
+
+  def testAsyncWriteErrorsPermissionsIssue(path: String): Unit = {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + path)
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2:" +
+    " cannot write offset files due to permissions issue") {
+    testAsyncWriteErrorsPermissionsIssue("/offsets")
+  }
+
+  // Tests that errors that occurred during async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2:" +
+    " cannot write commit files due to permissions issue") {
+    testAsyncWriteErrorsPermissionsIssue("/commits")
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+    extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach{ thread =>
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    }
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test("recovery when first offset is not zero and no commit log entries") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only
+    // contains batches 0, 2, 5 and the commit log only contains 0, 2
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // since batch 5 has an offset entry but no commit entry, it is re-run and covers 3, 4, 5
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test("recovery when gaps in offset and commit log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only
+    // contains batches 0, 2 and the commit log only contains 0
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+    // batches 0 and 3 should be logged

Review Comment:
   nit: comment seems to be out of sync



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {

Review Comment:
   I don't see any difference, but I won't nitpick further since this is probably a matter of preference. The Spark codebase tends to ask for inlining a method that is used only once, but I agree there are cases where not inlining is better for readability.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")

Review Comment:
   ```
   getListOfFiles(checkpointLocation + "/offsets")
     .filter(file => !file.isHidden)
     .map(file => file.getName.toInt)
     .sorted
   ```
   
   ```
   getListOfFiles(checkpointLocation + "/commits")
     .filter(file => !file.isHidden)
     .map(file => file.getName.toInt)
     .sorted
   ```
   
   Do you see the pattern? I wasn't referring to the "offsets" vs. "commits" directories. I was referring to the pattern of 1) listing files, 2) filtering out hidden files, 3) converting each file name to a number, and 4) sorting in ascending order. The pattern is repeated more than 10 times across test cases.
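   
   One possible way to factor the repeated pattern out, as a minimal sketch only: the object name `CheckpointLogFiles` and the method name `batchIds` are hypothetical and not part of this PR, and the sketch uses plain `java.io.File` rather than the suite's existing `getListOfFiles`.
   
   ```scala
   import java.io.File

   // Hypothetical helper (names are assumptions, not from the PR).
   object CheckpointLogFiles {
     // Lists the regular, non-hidden files in logDir, parses each file name
     // as a batch id, and returns the ids in ascending order.
     // Returns an empty Seq if logDir does not exist or is not a directory.
     def batchIds(logDir: String): Seq[Int] = {
       val dir = new File(logDir)
       val files: Seq[File] =
         if (dir.isDirectory) Option(dir.listFiles).map(_.toSeq).getOrElse(Seq.empty)
         else Seq.empty
       files
         .filter(file => file.isFile && !file.isHidden)
         .map(file => file.getName.toInt)
         .sorted
     }
   }
   ```
   
   With something like this, each assertion could shrink to a single line, e.g. `CheckpointLogFiles.batchIds(checkpointLocation + "/offsets") should equal(Seq(0, 3))`.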



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was enabled in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 0 through 6 should be logged even though async tracking is on
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been received in order, without duplicates
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for batches 11 through 20 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the offsets directory read-only
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the commits directory read-only
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),

Review Comment:
   > yes though if there are gaps in the offset log, more than 1 batch may be replayed
   
   Yeah, I wanted to know the details here: how does it work for this test case? Would Spark run a single microbatch as batch 2 that also includes the previously available data, or would Spark run two different microbatches, pre-batch 2 and batch 2? It looks like the former, but I wanted to be sure.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -342,17 +342,14 @@ class MicroBatchExecution(
         isCurrentBatchConstructed = true
         availableOffsets = nextOffsets.toStreamProgress(sources)
         /* Initialize committed offsets to a committed batch, which at this
-         * is the second latest batch id in the offset log. */
-        if (latestBatchId != 0) {
-          val secondLatestOffsets = offsetLog.get(latestBatchId - 1).getOrElse {
-            logError(s"The offset log for batch ${latestBatchId - 1} doesn't exist, " +
-              s"which is required to restart the query from the latest batch $latestBatchId " +
-              "from the offset log. Please ensure there are two subsequent offset logs " +
-              "available for the latest batch via manually deleting the offset file(s). " +
-              "Please also ensure the latest batch for commit log is equal or one batch " +
-              "earlier than the latest batch for offset log.")
-            throw new IllegalStateException(s"batch ${latestBatchId - 1} doesn't exist")
-          }
+         * is the second latest batch id in the offset log.

Review Comment:
   I came back here again because I feel it's not safe to make this change in normal microbatch execution.
   
   Say the user doesn't use async progress tracking at all. (Using a sink that is not on the async progress tracking support list, or using a stateful operator, are both valid.) And suppose there is a storage issue and the offset log somehow ends up with gaps.
   
   Previously, the query would fail fast. It would stop processing, but it wouldn't break anything.
   
   Now we allow the query to go through, silently breaking exactly-once semantics. Idempotent writes based on batch ID will be broken, because the offset range for the batch ID changes. The same goes for the state store.
   
   Wouldn't this be a more serious problem than lacking support for switching? If we really want to support switching, can we support it only when the checkpoint interval is disabled, so that we don't change normal microbatch execution in a way that could break fault-tolerance semantics?
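   
   To make the concern concrete, here is a minimal sketch of how a sink that deduplicates by batch ID can silently end up with wrong data once the offset range behind a batch ID changes. It is a simplified model, not Spark code: `OffsetRange`, `IdempotentSink`, and `GapReplayDemo` are hypothetical names, and the scenario (batch 2 rebuilt to cover offsets 0 to 3 after a gap) is an assumption chosen for illustration.

```scala
import scala.collection.mutable

// Hypothetical, simplified model; these are not Spark's actual classes.
final case class OffsetRange(start: Long, end: Long) // half-open range [start, end)

// A sink that deduplicates by batch ID, as idempotent sinks commonly do.
final class IdempotentSink {
  private val committed = mutable.Map.empty[Long, OffsetRange]

  // Returns true if the batch was written, false if it was skipped as a duplicate.
  def write(batchId: Long, range: OffsetRange): Boolean =
    if (committed.contains(batchId)) {
      false
    } else {
      committed.update(batchId, range)
      true
    }

  def rangeFor(batchId: Long): Option[OffsetRange] = committed.get(batchId)
}

object GapReplayDemo {
  def main(args: Array[String]): Unit = {
    val sink = new IdempotentSink

    // First run: batch 2 covered offsets [2, 3) and reached the sink,
    // but its commit log entry was lost.
    sink.write(2L, OffsetRange(2, 3))

    // Restart with a gap in the offset log (assumed scenario): batch 2 is
    // rebuilt from the beginning and now covers [0, 3).
    val written = sink.write(2L, OffsetRange(0, 3))

    // The sink treats the replay as a duplicate and keeps the old range,
    // so offsets [0, 2) never reach it under batch ID 2.
    println(s"replay written: $written, sink holds: ${sink.rangeFor(2L)}")
  }
}
```

   In this model the replayed batch is skipped because its ID was already seen, even though its contents differ. With the previous fail-fast behavior the query would have stopped before reaching this state.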



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+  extends CommitLog(sparkSession, path) {
+
+  // the cache needs to be enabled because we may not be persisting every entry to durable storage
+  // entries not persisted to durable storage will just be stored in memory for faster lookups
+  assert(metadataCacheEnabled == true)
+
+  // A queue of batches written to storage.  Used to keep track when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Writes a new batch to the commit log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompletableFuture that contains the batch id.  The future is completed when
+   *         the async write of the batch is completed.  Future may also be completed exceptionally
+   *         to indicate some write error.
+   */
+  def addAsync(batchId: Long, metadata: CommitMetadata): CompletableFuture[Long] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+    val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+      serialize(metadata, output)
+    }.thenApply((ret: Boolean) => {
+      if (ret) {
+        batchId
+      } else {
+        throw new IllegalStateException(
+          s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"
+        )
+      }
+    })
+
+    batchCache.put(batchId, metadata)
+    future
+  }
+
+  /**
+   * Adds batch to commit log only in memory and not persisted to durable storage. This method is
+   * used when we don't want to persist the commit log entry for every micro batch
+   * to durable storage
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return true if operation is successful otherwise false.
+   */
+  def addInMemory(batchId: Long, metadata: CommitMetadata): Boolean = {
+    if (batchCache.containsKey(batchId)) {
+      false
+    } else {
+      batchCache.put(batchId, metadata)
+      true
+    }
+  }
+
+  /**
+   * Purge entries in the commit log up to thresholdBatchId.
+   * @param thresholdBatchId
+   */
+  override def purge(thresholdBatchId: Long): Unit = {
+    super.purge(thresholdBatchId)
+  }
+
+  /**
+   * Adds new batch asynchronously
+   * @param batchId id of batch to write
+   * @param fn serialization function
+   * @return CompletableFuture that contains a boolean to
+   *         indicate whether the write was successful or not.
+   *         Future can also be completed exceptionally to indicate write errors.
+   */
+  private def addNewBatchByStreamAsync(batchId: Long)(
+    fn: OutputStream => Unit): CompletableFuture[Boolean] = {
+    val future = new CompletableFuture[Boolean]()
+    val batchMetadataFile = batchIdToPath(batchId)
+
+    if (batchCache.containsKey(batchId)) {
+      future.complete(false)
+      future
+    } else {
+      executorService.submit(new Runnable {
+        override def run(): Unit = {
+          try {
+            if (fileManager.exists(batchMetadataFile)) {
+              future.complete(false)
+            } else {
+              val start = System.currentTimeMillis()
+              write(
+                batchMetadataFile,
+                fn
+              )
+              logDebug(
+                s"Completion commit for batch${batchId} took" +

Review Comment:
   nit: add a space between `batch` and `${batchId}`, to be consistent with the logError message below.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because commit log is on previously
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // no duplicates expected: data is compared without deduplication
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // no additional commit log entries should be logged since async offset commit is on
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 20 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // just reprocessing batch 20; no more offset log entries should be added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 20 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are being executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // process batches 0 through 5 so that both logs are populated before gaps are created
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // process batches 0 through 2 so that both logs are populated before gaps are created
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),

Review Comment:
   So there are offsets 0 and 3 and commit 0. Clearly we don't replay batch 0, but we do have to replay everything after batch 0 up to batch 3, the latest batch in the offset log. From the data point of view, that is 1 and 2 (not logged in the offset log) plus 3 (logged in the offset log as batch 3).
   
   If Spark processes 1 -> 3 and then 3 -> 4, then 1 -> 3 would run as batch 3, ignoring the offset range logged in the offset log for batch 3, right? So when users switch to the normal processing trigger, the fault-tolerance semantics are still those of async progress tracking, i.e. "at least once", until the switched query processes a batch and the offset log becomes contiguous. That does not seem easy to reason about.
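   
   To make the scenario concrete, here is a rough sketch of the range I have in mind. It is a simplified model, not Spark's recovery code: `LoggedBatch`, `ReplayRangeSketch` and `replayRange` are hypothetical names, and I am assuming each offset log entry records only the end offset of its batch, as in this test where every `AddData` advances the offset by one.
   
   ```scala
   // Hypothetical, simplified model of the offset/commit logs; not Spark's API.
   final case class LoggedBatch(batchId: Long, endOffset: Long)

   object ReplayRangeSketch {
     // Returns (batchId, startOffsetExclusive, endOffsetInclusive) for the batch
     // that would be re-run on restart, or None if the latest batch is committed.
     def replayRange(
         offsetLog: Seq[LoggedBatch],
         committedBatchIds: Set[Long]): Option[(Long, Long, Long)] = {
       val sorted = offsetLog.sortBy(_.batchId)
       sorted.lastOption
         .filter(latest => !committedBatchIds.contains(latest.batchId))
         .map { latest =>
           // Start after the newest committed batch that is still in the offset log;
           // -1 stands for "from the beginning" when there is none.
           val start = sorted
             .filter(b => committedBatchIds.contains(b.batchId))
             .lastOption
             .map(_.endOffset)
             .getOrElse(-1L)
           (latest.batchId, start, latest.endOffset)
         }
     }

     def main(args: Array[String]): Unit = {
       // Offsets 0 and 3 are logged, only batch 0 is committed.
       val offsets = Seq(LoggedBatch(0, 0), LoggedBatch(3, 3))
       println(replayRange(offsets, Set(0L)))
     }
   }
   ```
   
   With offsets 0 and 3 and commit 0, this model re-runs batch 3 over everything after offset 0 up to offset 3, i.e. data 1, 2 and 3 in a single batch, which is the range the test's `CheckNewAnswer(1, 2, 3, 4)` appears to expect once 4 is added.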



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncOffsetSeqLog.scala:
##########
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.{Clock, SystemClock}
+
+/**
+ * Used to write entries to the offset log asynchronously
+ */
+class AsyncOffsetSeqLog(
+  sparkSession: SparkSession,

Review Comment:
   nit: 4 spaces for params, 2 spaces for extends. 
   
   https://github.com/databricks/scala-style-guide#indent
   
   > For classes whose header doesn't fit in two lines, use 4 space indentation for its parameters, put each in each line, put the extends on the next line with 2 space indent, and add a blank line after class header.
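   
   For illustration, a header laid out that way could look like the sketch below. The class and parameter names are made up for the example and are not taken from this PR.
   
   ```scala
   // Hypothetical classes, used only to show the indentation rule.
   abstract class BaseLog(val location: String)

   class ExampleAsyncLog(
       path: String,            // parameters: 4-space indent, one per line
       commitIntervalMs: Long,
       maxPendingWrites: Int)
     extends BaseLog(path) {    // extends: 2-space indent on its own line

     // blank line after the class header, then the body
     def describe: String =
       s"$location every ${commitIntervalMs}ms, up to $maxPendingWrites pending"
   }
   ```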



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {

Review Comment:
   I meant placing the class first, and then the object. That ordering is widely used in the codebase for a class and its companion object.
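   
   A minimal sketch of the ordering, with hypothetical names (`IntervalTracker`, `INTERVAL_MS_KEY`) that only stand in for the real class and its option keys:
   
   ```scala
   // Class first ...
   class IntervalTracker(options: Map[String, String]) {
     // A class can call private members of its companion object.
     val intervalMs: Long = IntervalTracker.parseIntervalMs(options)
   }

   // ... then the companion object holding constants and helpers.
   object IntervalTracker {
     val INTERVAL_MS_KEY = "checkpointIntervalMs"

     private def parseIntervalMs(options: Map[String, String]): Long =
       options.getOrElse(INTERVAL_MS_KEY, "1000").toLong
   }
   ```
   
   Both definitions need to live in the same source file for the object to count as the companion, so the reordering is purely a readability change.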



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check if an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"
+          )
+        }
+      }
+      offsetLog.removeAsyncOffsetWrite(currentBatchId)
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  // need to look at the number of files on disk
+  override def purge(threshold: Long): Unit = {
+    while (offsetLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      offsetLog.writtenToDurableStorage.poll()
+    }
+    offsetLog.purge(offsetLog.writtenToDurableStorage.peek())
+
+    while (commitLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      commitLog.writtenToDurableStorage.poll()
+    }
+    commitLog.purge(commitLog.writtenToDurableStorage.peek())
+  }
+
+  override def cleanup(): Unit = {
+    super.cleanup()
+
+    ThreadUtils.shutdown(asyncWritesExecutorService)
+    logInfo(s"Async progress tracking executor pool for query ${prettyIdString} has been shutdown")
+  }
+
+  // used for testing
+  def areWritesPendingOrInProgress(): Boolean = {
+    asyncWritesExecutorService.getQueue.size() > 0 || asyncWritesExecutorService.getActiveCount > 0
+  }
+
+  private def validateAndGetTrigger(): TriggerExecutor = {
+    // validate that the pipeline is using a supported sink
+    if (!extraOptions
+      .get(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK
+      )
+      .getOrElse("false")
+      .toBoolean) {
+      try {
+        plan.sink.name() match {
+          case "noop-table" =>
+          case "console" =>
+          case "MemorySink" =>
+          case "KafkaTable" =>
+          case _ =>
+            throw new IllegalArgumentException(
+              s"Sink ${plan.sink.name()}" +
+                s" does not support async progress tracking"
+            )
+        }
+      } catch {
+        case e: IllegalStateException =>
+          // sink does not implement name() method
+          if (e.getMessage.equals("should not be called.")) {

Review Comment:
   I actually meant that a sink throwing IllegalStateException when name() is called is not something async progress tracking needs to support, but swallowing the exception doesn't seem good either. I'm OK with leaving this as it is.
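   
   If this is revisited later, one possible shape is sketched below. It is only an illustration under stated assumptions: `NamedSink` and `SinkSupportCheck` are hypothetical names, not Spark classes, and the supported sink names are copied from the match expression in the diff. The idea is to keep a failure from name() as the cause of the validation error rather than matching on the exception message.
   
   ```scala
   import scala.util.{Failure, Success, Try}

   // Hypothetical stand-in for a sink; only the name() accessor matters here.
   trait NamedSink { def name(): String }

   object SinkSupportCheck {
     // Sink names copied from the match expression in the diff above.
     val supportedSinks: Set[String] = Set("noop-table", "console", "MemorySink", "KafkaTable")

     // Throws IllegalArgumentException for unsupported sinks. A failure inside name()
     // is attached as the cause instead of being swallowed.
     def validate(sink: NamedSink): Unit = Try(sink.name()) match {
       case Success(n) if supportedSinks.contains(n) => ()
       case Success(n) =>
         throw new IllegalArgumentException(s"Sink $n does not support async progress tracking")
       case Failure(e: IllegalStateException) =>
         throw new IllegalArgumentException(
           "Sink does not expose a name, so async progress tracking support cannot be verified", e)
       case Failure(e) => throw e
     }
   }
   ```
   
   With this shape the caller sees a single validation error in both cases, and the original IllegalStateException stays available through getCause.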



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncOffsetSeqLog.scala:
##########
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.{Clock, SystemClock}
+
+/**
+ * Used to write entries to the offset log asynchronously
+ */
+class AsyncOffsetSeqLog(
+    sparkSession: SparkSession,
+    path: String,
+    executorService: ThreadPoolExecutor,
+    offsetCommitIntervalMs: Long,
+    clock: Clock = new SystemClock())
+    extends OffsetSeqLog(sparkSession, path) {
+
+  // the cache needs to be enabled because we may not be persisting every entry to durable storage
+  // entries not persisted to durable storage will just be stored in memory for faster lookups
+  assert(metadataCacheEnabled == true)
+
+  // A map of the current pending offset writes. Key -> batch Id, Value -> CompletableFuture
+  // Used to determine if a commit log entry for this batch also needs to be persisted to storage
+  private val pendingOffsetWrites = new ConcurrentHashMap[Long, CompletableFuture[Long]]()
+
+  // Keeps track of the last time a commit was issued. Used for issuing commits to storage at
+  // the configured intervals
+  private val lastCommitIssuedTimestampMs: AtomicLong = new AtomicLong(-1)
+
+  // A queue of batches written to storage.  Used to keep track of when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Get an async offset write by batch id.  Used to check whether a corresponding commit log
+   * entry needs to be written to durable storage as well
+   * @param batchId
+   * @return an option indicating whether an async offset write was issued for the given batch id
+   */
+  def getAsyncOffsetWrite(batchId: Long): Option[CompletableFuture[Long]] = {
+    Option(pendingOffsetWrites.get(batchId))
+  }
+
+  /**
+   * Remove the async offset write when we don't need to keep track of it anymore
+   * @param batchId
+   */
+  def removeAsyncOffsetWrite(batchId: Long): Unit = {
+    pendingOffsetWrites.remove(batchId)
+  }
+
+  /**
+   * Writes a new batch to the offset log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompletableFuture that contains the batch id.  The future is completed when
+   *         the async write of the batch is completed.  Future may also be completed exceptionally
+   *         to indicate some write error.
+   */
+  def addAsync(batchId: Long, metadata: OffsetSeq): CompletableFuture[(Long, Boolean)] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+
+    def issueAsyncWrite(batchId: Long): CompletableFuture[Long] = {
+      lastCommitIssuedTimestampMs.set(clock.getTimeMillis())
+      val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+        serialize(metadata, output)
+      }.thenApply((ret: Boolean) => {
+        if (ret) {
+          batchId
+        } else {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"

Review Comment:
   OK, at least it is out of scope for this PR.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true

Review Comment:
   OK fair enough.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)

Review Comment:
   MemoryStream has an apply method in its companion object.
   
   ```
   object MemoryStream {
     protected val currentBlockId = new AtomicInteger(0)
     protected val memoryStreamId = new AtomicInteger(0)
   
     def apply[A : Encoder](implicit sqlContext: SQLContext): MemoryStream[A] =
       new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
   
     def apply[A : Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): MemoryStream[A] =
       new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, Some(numPartitions))
   }
   ```
   
   In most cases you don't need to pass any parameters.
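   
   To illustrate the pattern outside of Spark, here is a minimal, self-contained sketch. `SimpleStream` is a hypothetical stand-in, not the real MemoryStream: it only mirrors the idea that the companion's apply hands out ids from a shared counter, so callers never pass `id = 0` by hand. In the suite itself this would presumably let `new MemoryStream[Int](id = 0, sqlContext = sqlContext)` become `MemoryStream[Int]`, assuming an implicit SQLContext is in scope.
   
   ```scala
   import java.util.concurrent.atomic.AtomicInteger

   // Hypothetical stand-in for MemoryStream; it is not Spark's class.
   class SimpleStream[A](val id: Int, val numPartitions: Option[Int] = None)

   object SimpleStream {
     // Shared counter so that every stream created through apply gets a distinct id.
     private val nextId = new AtomicInteger(0)

     def apply[A](): SimpleStream[A] = new SimpleStream[A](nextId.getAndIncrement())

     def apply[A](numPartitions: Int): SimpleStream[A] =
       new SimpleStream[A](nextId.getAndIncrement(), Some(numPartitions))
   }

   object SimpleStreamDemo {
     def main(args: Array[String]): Unit = {
       val first = SimpleStream[Int]()
       val second = SimpleStream[Int](4)
       // The ids differ because both came from the shared counter.
       println(s"first=${first.id} second=${second.id} partitions=${second.numPartitions}")
     }
   }
   ```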



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+    extends CommitLog(sparkSession, path) {
+
+  // A queue of batches written to storage.  Used to keep track of when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Writes a new batch to the commit log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompletableFuture that contains the batch id.  The future is completed when
+   *         the async write of the batch is completed.  Future may also be completed exceptionally
+   *         to indicate some write error.
+   */
+  def addAsync(batchId: Long, metadata: CommitMetadata): CompletableFuture[Long] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+    val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+      serialize(metadata, output)
+    }.thenApply((ret: Boolean) => {
+      if (ret) {
+        batchId
+      } else {
+        throw new IllegalStateException(
+          s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"

Review Comment:
   Let's consider this out of scope. I think it is still worth doing, since it's not advisable to make users dig through a cryptic stack trace, but it is fine for this PR because the behavior is the same as before.
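   
   For a possible follow-up, a rough sketch of what a more descriptive failure could look like. Everything here is an assumption for illustration: `ConcurrentLogUpdateException`, `toBatchId`, and the `logPath` parameter are hypothetical and do not exist in Spark. Only the `thenApply` mapping from the Boolean write result to the batch id mirrors the diff.
   
   ```scala
   import java.util.concurrent.CompletableFuture

   object ConcurrentUpdateSketch {
     // Hypothetical exception type carrying a message that points at the likely cause.
     final class ConcurrentLogUpdateException(batchId: Long, logPath: String)
         extends IllegalStateException(
           s"Batch $batchId already exists in the log at $logPath. " +
             "Another streaming query may be writing to the same checkpoint location.")

     // Maps the Boolean result of a log write to the batch id, or fails the future
     // with the descriptive exception when the write reports a conflict.
     def toBatchId(
         write: CompletableFuture[Boolean],
         batchId: Long,
         logPath: String): CompletableFuture[Long] = {
       val result: CompletableFuture[Long] = write.thenApply((ok: Boolean) => {
         if (ok) {
           batchId
         } else {
           throw new ConcurrentLogUpdateException(batchId, logPath)
         }
       })
       result
     }
   }
   ```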



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {

Review Comment:
   I'm not sure I understand. Where do you expect the timeout to happen? The test timeouts you may have seen in streaming test suites mostly occur under testStream().
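   
   If the concern is the blocking `sem.acquire()` in this loop (an assumption on my side about what was meant), the waits can be bounded with plain java.util.concurrent calls. The sketch below is standalone and uses a worker thread as a stand-in for the sink. `BoundedWaitSketch` and `runBatches` are hypothetical names. `Semaphore.tryAcquire(timeout, unit)` and `CountDownLatch.await(timeout, unit)` both return false on timeout, so the result can be asserted instead of ignored.
   
   ```scala
   import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}

   object BoundedWaitSketch {
     // Runs `batches` simulated batches and waits for each one with a timeout.
     // Returns true only if every wait finished within timeoutMs.
     def runBatches(batches: Int, timeoutMs: Long): Boolean = {
       val latch = new CountDownLatch(batches)
       val sem = new Semaphore(1)
       var acquiredAll = true
       var i = 0
       while (i < batches && acquiredAll) {
         // tryAcquire returns false on timeout instead of blocking forever like acquire()
         acquiredAll = sem.tryAcquire(timeoutMs, TimeUnit.MILLISECONDS)
         if (acquiredAll) {
           // Stand-in for the streaming sink: finish the batch, then release the permit.
           val worker = new Thread(() => { latch.countDown(); sem.release() })
           worker.setDaemon(true)
           worker.start()
         }
         i += 1
       }
       // await returns false if the count did not reach zero within the timeout
       acquiredAll && latch.await(timeoutMs, TimeUnit.MILLISECONDS)
     }
   }
   ```
   
   In a test this would be used as `assert(BoundedWaitSketch.runBatches(10, 5000))`, so that a stuck batch fails the test rather than hanging it.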



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 0 through 6 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(

Review Comment:
   OK, let's leave this as it is. I may revisit this myself eventually to sort it out.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 0 through 6 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by pre-creating a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)

Review Comment:
   ```
   eventually(timeout(Span(5, Seconds))) {
         val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
           .filter(file => !file.isHidden)
           .map(file => file.getName.toInt)
           .sorted
   
         val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
           .filter(file => !file.isHidden)
           .map(file => file.getName.toInt)
           .sorted
   
         offsetLogFiles should equal (commitLogFiles)
       }
   ```
   
   This block is already used as the exit condition, so if the test does not time out, this assertion has already been fulfilled.
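   
   A minimal sketch of the reasoning, in plain Scala. The names `batchIds` and `pollUntil` are hypothetical and not part of the PR; `pollUntil` only stands in for ScalaTest's `eventually`. Once the polling condition on the two directory listings has passed, comparing the same listings again cannot fail unless files are written in between.
   
   ```scala
   import java.io.File

   object CheckpointLogCheck {
     // Hypothetical helper: sorted batch ids of the visible files in a log directory.
     // Returns an empty sequence if the directory does not exist.
     def batchIds(dir: File): Seq[Int] =
       Option(dir.listFiles)
         .getOrElse(Array.empty[File])
         .filter(f => f.isFile && !f.isHidden)
         .map(_.getName.toInt)
         .sorted
         .toSeq

     // Hypothetical stand-in for `eventually`: re-evaluate the condition until it
     // holds or the timeout elapses, and report whether it ever held.
     def pollUntil(timeoutMs: Long, intervalMs: Long = 50)(cond: => Boolean): Boolean = {
       val deadline = System.nanoTime() + timeoutMs * 1000000L
       var ok = cond
       while (!ok && System.nanoTime() < deadline) {
         Thread.sleep(intervalMs)
         ok = cond
       }
       ok
     }
   }
   ```
   
   With helpers like these, the exit condition would read `pollUntil(5000)(batchIds(offsetsDir) == batchIds(commitsDir))`, where `offsetsDir` and `commitsDir` are assumed to point at the checkpoint's `offsets` and `commits` directories.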



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been processed, in order and without duplicates
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // no additional commit log entries should be logged since async offset commit is on
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 20 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // just reprocessing batch 20 should not add more offset log entries
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 20 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occurred during async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occurred during async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occurred during async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occurred during async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {

Review Comment:
   But we are making changes for the normal processing time trigger too, right? The behavior tested here actually comes from the changes in MicroBatchExecution, not the async one. The assumption of having at most one gap between the offset log and the commit log no longer holds even for the processing time trigger, and this test exercises the new condition. Please keep the code change and the test suite closer to each other.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())

Review Comment:
   My bad. The commit log is only used to work around the edge case of a specific DSv1 source (calling getBatch() with the offset of the "previous" batch).
   
   We are OK as long as it fits with the logic below (determining which batch ID to restart from):
   
   ```
   val prevBatchId = offsetLog.getPrevBatchFromStorage(latestBatchId)
   if (latestBatchId != 0 && prevBatchId.isDefined) {
     val secondLatestOffsets = offsetLog.get(prevBatchId.get).getOrElse({
       throw new IllegalStateException(s"Offset metadata for batch ${prevBatchId}" +
         s" cannot be found.  This should not happen.")
     })
     committedOffsets = secondLatestOffsets.toStreamProgress(sources)
   }
   ```
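   
   A minimal sketch of how that lookup behaves when the offset log has gaps (e.g. only batches 0 and 3 persisted). `ToyOffsetLog` and `RestartSketch` are hypothetical in-memory stand-ins, not Spark classes, and offsets are modelled as plain strings; only the method names `getPrevBatchFromStorage` and `get` mirror the snippet above, and their behavior here is an assumption:
   
   ```scala
   import scala.collection.immutable.TreeMap

   // Hypothetical stand-in for the offset log: batch ID -> offsets (as a string).
   // Not Spark's AsyncOffsetSeqLog; only the lookup shape is modelled.
   final case class ToyOffsetLog(entries: TreeMap[Long, String]) {
     def get(batchId: Long): Option[String] = entries.get(batchId)

     // Largest persisted batch ID strictly below `batchId`, if any.
     // TreeMap keys are sorted, so the last match is the closest predecessor.
     def getPrevBatchFromStorage(batchId: Long): Option[Long] =
       entries.keys.filter(_ < batchId).lastOption
   }

   object RestartSketch {
     // Offsets to treat as already committed when restarting at `latestBatchId`.
     // Returns None when there is no earlier persisted batch.
     def committedOffsetsFor(log: ToyOffsetLog, latestBatchId: Long): Option[String] = {
       val prevBatchId = log.getPrevBatchFromStorage(latestBatchId)
       if (latestBatchId != 0 && prevBatchId.isDefined) {
         Some(log.get(prevBatchId.get).getOrElse(
           throw new IllegalStateException(
             s"Offset metadata for batch ${prevBatchId.get} cannot be found.")))
       } else {
         None
       }
     }
   }
   ```
   
   Under this model, with entries for batches 0 and 3 only, restarting at batch 3 takes the offsets of batch 0 as the committed position, so the lookup does not depend on batch IDs being contiguous.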



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 0 through 6 should be logged, including the batches run asynchronously
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries should also be logged for the 10 new batches
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // the new data is processed as batch 21, so one more offset log entry should be added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by creating a file for the next offset, i.e. 1;
+        // this should cause a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by creating a file for the next commit, i.e. 1
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the offset log directory read-only
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the commit log directory read-only
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 5 so that offset and commit log entries can be deleted to create gaps
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()

Review Comment:
   If you imagine you're the reviewer of this code, it's not easy to recognize which batches the code is intentionally leaving in place.
   
   ```
   listFiles("offsets").filterNot(f => f.getName.startsWith("2") || f.getName.startsWith("5")).foreach(_.delete())
   listFiles("commits").filterNot(f => f.getName.startsWith("2")).foreach(_.delete())
   ```
   
   The code above would give a clearer indication that we intentionally leave 2 and 5 for offsets, and 2 for commits.
   
   That said, I'm also OK with just adding a code comment that explicitly states the intention.
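   
   As a rough, self-contained sketch of the same idea with the kept batches named explicitly: `retainOnly` and `batchIds` are hypothetical helpers, not part of this PR, and the directory layout is set up by hand here rather than by a real stream. Comparing parsed batch IDs instead of using `startsWith` also avoids matching e.g. batch 20 when keeping batch 2. It assumes every visible file name in the directory is a numeric batch ID.
   
   ```scala
   import java.io.File
   import java.nio.file.Files
   
   object CheckpointGaps {
     // Hypothetical helper: visible regular files in `dir` (empty if `dir` is missing).
     private def visibleFiles(dir: File): Array[File] =
       Option(dir.listFiles).getOrElse(Array.empty[File])
         .filter(f => f.isFile && !f.isHidden)
   
     // Hypothetical helper: delete every batch file whose ID is not in `keep`.
     // Assumes file names are numeric batch IDs; a non-numeric name would throw.
     def retainOnly(dir: File, keep: Set[Long]): Unit =
       visibleFiles(dir)
         .filterNot(f => keep.contains(f.getName.toLong))
         .foreach(_.delete())
   
     // Hypothetical helper: batch IDs still present in `dir`, ascending.
     def batchIds(dir: File): List[Long] =
       visibleFiles(dir).map(_.getName.toLong).sorted.toList
   
     def main(args: Array[String]): Unit = {
       // Fake a checkpoint directory holding batches 0 to 5 in both logs.
       val root = Files.createTempDirectory("streaming.metadata").toFile
       val offsets = new File(root, "offsets")
       val commits = new File(root, "commits")
       Seq(offsets, commits).foreach { d =>
         d.mkdirs()
         (0 to 5).foreach(i => new File(d, i.toString).createNewFile())
       }
   
       // The intent is visible at the call site: keep 2 and 5 for offsets, 2 for commits.
       retainOnly(offsets, keep = Set(2L, 5L))
       retainOnly(commits, keep = Set(2L))
   
       println(batchIds(offsets))
       println(batchIds(commits))
     }
   }
   ```
   
   Running `main` should leave only batches 2 and 5 in `offsets` and only batch 2 in `commits`.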



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario with gaps in both logs: after the deletions below, the offset log
+    // contains batches 0, 2 and 5 and the commit log contains batches 0 and 2
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),

Review Comment:
   > Which behavior happens under the hood? Would Spark process two batches, 1) a new batch (replacing batches 3 and 4) and 2) batch 5 as recorded in the offset log, or would Spark process just a single batch, ignoring the offset range recorded in the offset log for batch 5?
   
   Can we comment clearly which one happens?
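   
   To make such a comment checkable, the test could state the on-disk contents of each log explicitly before and after the restart. A minimal sketch, assuming a hypothetical helper `batchIds` (not part of this PR) that mirrors the `getListOfFiles(...).filter(...).map(...).sorted` chain repeated throughout the suite:
   
   ```scala
   import java.io.File
   import java.nio.file.Files
   
   object BatchIdListing {
     // Hypothetical helper, not part of the PR: returns the sorted batch ids
     // whose log files exist in `dir`. Hidden files (for example `.3.crc`
     // checksum files on Unix-like systems) are skipped before parsing names.
     def batchIds(dir: String): Seq[Int] = {
       val d = new File(dir)
       if (d.exists && d.isDirectory) {
         d.listFiles.toSeq
           .filter(f => f.isFile && !f.isHidden)
           .map(_.getName.toInt)
           .sorted
       } else {
         Seq.empty
       }
     }
   
     def main(args: Array[String]): Unit = {
       // Simulate an offset log with gaps using a throwaway directory.
       val dir = Files.createTempDirectory("offsets").toFile
       Seq("5", "0", "2").foreach(name => new File(dir, name).createNewFile())
       println(batchIds(dir.getPath))
     }
   }
   ```
   
   With a helper like this, an assertion such as `batchIds(checkpointLocation + "/offsets") should equal(Seq(0, 2, 5))` documents the state of the log at each step, so the expected recovery behavior can be written next to it.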



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // data should contain every value exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the new batches should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // the newly added data is processed as batch 21, so one more offset log entry is added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error by writing a file for the next offset i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log contains batches 0, 2 and 5
+    // and the commit log contains only batches 0 and 2
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log contains batches 0 and 2
+    // and the commit log contains only batch 0
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1

Review Comment:
   The test never uses async progress tracking at all, yet here we are changing the behavior, which affects the fault-tolerance semantics of the normal processing trigger. I don't think that is something we want to do.
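
   For readers following the thread, the on-disk state that the quoted test sets up (offset log entries without matching commit log entries) can be inspected with a small helper. This is only a minimal sketch: `CheckpointBatchIds` and `batchIds` are hypothetical names that are not part of the PR, and the code merely mirrors the `getListOfFiles(...).filter(...).map(...).sorted` expression repeated throughout the suite. It does not model how Spark itself decides where to resume.

```scala
import java.io.File
import java.nio.file.Files

object CheckpointBatchIds {
  // Hypothetical helper (not part of the PR): list the batch IDs recorded in one
  // checkpoint subdirectory. Hidden files such as ".2.crc" are skipped; this relies
  // on File.isHidden treating dot-files as hidden, as on Unix-like systems.
  def batchIds(dir: File): List[Int] =
    Option(dir.listFiles).getOrElse(Array.empty[File])
      .filter(f => f.isFile && !f.isHidden)
      .map(_.getName.toInt)
      .sorted
      .toList

  def main(args: Array[String]): Unit = {
    // Recreate the layout from the quoted test: offsets 0 and 2, commits 0 only.
    val checkpoint = Files.createTempDirectory("streaming.metadata").toFile
    val offsets = new File(checkpoint, "offsets")
    val commits = new File(checkpoint, "commits")
    offsets.mkdirs()
    commits.mkdirs()
    Seq("0", "2", ".2.crc").foreach(name => new File(offsets, name).createNewFile())
    new File(commits, "0").createNewFile()

    // Batches that have an offset log entry but no commit log entry.
    val uncommitted = batchIds(offsets).diff(batchIds(commits))
    println(s"offsets=${batchIds(offsets)} commits=${batchIds(commits)} uncommitted=$uncommitted")
  }
}
```

   Printing the two lists side by side makes it easier to see which batches a given test expects to be replayed, without asserting anything about the recovery logic under discussion.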



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org




[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049028853


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged with async progress tracking on
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are being executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +

Review Comment:
   ok
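
The hunk above repeats the same `getListOfFiles(...).filter(...).map(...).sorted` chain for every check of the `offsets` and `commits` directories. As a minimal sketch of how that chain could be factored out (this is an assumption for illustration, not code from the PR; the object name `CheckpointLogFiles` and the method `loggedBatchIds` are hypothetical), a helper along these lines would return the batch IDs found in one log directory:

```scala
import java.io.File
import java.nio.file.Files

// Hypothetical helper, not part of the PR: lists the batch IDs that have a
// file in a checkpoint sub-directory such as "offsets" or "commits".
object CheckpointLogFiles {
  def loggedBatchIds(checkpointLocation: String, logName: String): Seq[Int] = {
    val dir = new File(checkpointLocation, logName)
    // File.listFiles returns null when the path is missing or not a directory
    Option(dir.listFiles).map(_.toSeq).getOrElse(Seq.empty)
      .filter(f => f.isFile && !f.isHidden)
      .map(_.getName.toInt)
      .sorted
  }

  def main(args: Array[String]): Unit = {
    // Build a throwaway directory that mimics a checkpoint layout
    val root = Files.createTempDirectory("streaming.metadata").toFile
    val offsets = new File(root, "offsets")
    offsets.mkdirs()
    Seq("2", "0", "1").foreach(name => new File(offsets, name).createNewFile())

    assert(loggedBatchIds(root.getPath, "offsets") == Seq(0, 1, 2))
    // A directory that was never created yields an empty result
    assert(loggedBatchIds(root.getPath, "commits").isEmpty)
  }
}
```

With such a helper, an assertion in the suite could read `loggedBatchIds(checkpointLocation, "commits") should equal(Seq(0, 1, 2))`. Like the original chain, it assumes every visible file name in the directory parses as an integer.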







[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049018217


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should be logged as well (checkpoint interval is 0)
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage

Review Comment:
   yes





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049094888


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should be logged as well (checkpoint interval is 0)
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // every value should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // the commit log should contain entries for the same 10 additional batches
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1;
+        // this should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next commit, i.e. 1
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the offsets directory read-only
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the commits directory read-only
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are being executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario with gaps in both logs: after the deletions below, the offset log
+    // contains batches 0, 2 and 5 and the commit log contains batches 0 and 2
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario with gaps in both logs: after the deletions below, the offset log
+    // contains batches 0 and 2 and the commit log contains only batch 0
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // restart with the default trigger and clock
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+  }
+
+  test("switching async progress tracking with interval commits on and off") {
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    var clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    inputData.commits.length should be(0)
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(3).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(4).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    /**
+     * Turn async progress tracking on
+     */
+    clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 6), // will get persisted since first batch on restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      AddData(inputData, 7),
+      AdvanceManualClock(100),
+      CheckNewAnswer(7),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(5).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 8),
+      AdvanceManualClock(100),
+      CheckNewAnswer(8),
+      AddData(inputData, 9),
+      AdvanceManualClock(800), //  persist offset
+      CheckNewAnswer(9),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    // simulate batch 9 not having a commit log entry
+    new File(checkpointLocation + "/commits/9").delete()
+    new File(checkpointLocation + "/commits/.9.crc").delete()
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 10),
+      CheckNewAnswer(7, 8, 9, 10),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(9).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 11),
+      CheckNewAnswer(11),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(10).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0, 3, 4, 5, 6, 9, 10 and 11 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9, 10, 11))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9, 10, 11))
+  }
+
+  test("Fail on stateful pipelines") {
+    val rateStream = spark.readStream
+      .format("rate")
+      .option("numPartitions", 1)
+      .option("rowsPerSecond", 10)
+      .load()
+      .toDF()
+
+    val windowedStream = rateStream
+      .withWatermark("timestamp", "0 seconds")
+      .groupBy(window(column("timestamp"), "10 seconds"), column("value"))
+      .count()
+
+    val e = intercept[StreamingQueryException] {
+      val query = windowedStream.writeStream
+        .format("noop")
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+
+      query.processAllAvailable()
+    }
+    e.getMessage should include(
+      "Stateful streaming queries does not support async progress tracking at this moment."
+    )
+  }
+
+  test("Fail on pipelines using unsupported sinks") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("parquet")
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(
+          "checkpointLocation",
+          Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+        )
+        .start("/tmp")
+    }
+
+    e.getMessage should equal("Sink FileSink[/tmp] does not support async progress tracking")
+  }
+
+  test("with log purging") {
+
+    withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "false") {
+      withTempDir { checkpointLocation =>
+        val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+        val ds = inputData.toDS()
+
+        val clock = new StreamManualClock
+        testStream(
+          ds,
+          extraOptions = Map(
+            ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+            ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+          )
+        )(
+          // need to set processing time to something so manual clock will work
+          StartStream(
+            Trigger.ProcessingTime("1 millisecond"),
+            checkpointLocation = checkpointLocation.getCanonicalPath,
+            triggerClock = clock
+          ),
+          AddData(inputData, 0),
+          AdvanceManualClock(100),
+          CheckNewAnswer(0),
+          AddData(inputData, 1),
+          AdvanceManualClock(100),
+          CheckNewAnswer(1),
+          AddData(inputData, 2),
+          AdvanceManualClock(100),
+          CheckNewAnswer(2),
+          AddData(inputData, 3),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(3),
+          AddData(inputData, 4),
+          AdvanceManualClock(100),
+          CheckNewAnswer(4),
+          AddData(inputData, 5),
+          AdvanceManualClock(100),
+          CheckNewAnswer(5),
+          AddData(inputData, 6),
+          AdvanceManualClock(100),
+          CheckNewAnswer(6),
+          AddData(inputData, 7),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(7),
+          Execute { q =>
+            // wait for all async log writes to finish
+            waitPendingOffsetWrites(q)
+          },
+          // add a new row to make sure log purge has kicked in.
+          // There can be a race condition in which the commit log entry for the previous batch
+          // may or may not be written to disk yet before the log purge is called.
+          // Adding another batch here will make sure purge is called on the correct number of
+          // offset and commit log entries
+          AddData(inputData, 8),
+          AdvanceManualClock(100),
+          CheckNewAnswer(8),
+          Execute { q =>
+            waitPendingOffsetWrites(q)
+            getListOfFiles(checkpointLocation + "/offsets")

Review Comment:
   sure





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049070590


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // every value should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 20 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // the newly added data is processed as batch 21, so one more offset log entry should be added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 20 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " not commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),

Review Comment:
   > And is the behavior same with previous or it is a behavioral change for normal processing time trigger?
   
   Yes, though if there are gaps in the offset log, more than one batch may be replayed.
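
To illustrate the point about gaps, here is a deliberately simplified model, not Spark's actual recovery code. It assumes that every batch executed after the newest durable offset log entry has to be replayed on restart. The names `ReplaySketch` and `batchesToReplay` are hypothetical and exist only for this sketch.

```scala
object ReplaySketch {
  // Hypothetical helper (not a Spark API): returns the batch ids that ran after
  // the newest offset log entry that reached durable storage.
  def batchesToReplay(persistedOffsets: Seq[Long], lastExecutedBatch: Long): Seq[Long] = {
    // -1 stands for "nothing was persisted yet"
    val lastPersisted = if (persistedOffsets.isEmpty) -1L else persistedOffsets.max
    ((lastPersisted + 1) to lastExecutedBatch).toList
  }

  def main(args: Array[String]): Unit = {
    // Offsets persisted only for batches 0 and 3, while the query ran up to batch 5.
    println(batchesToReplay(Seq(0L, 3L), 5L)) // List(4, 5)
    // Every batch persisted: nothing beyond the last entry is left to replay.
    println(batchesToReplay(Seq(0L, 1L, 2L, 3L), 3L)) // List()
  }
}
```

Under this assumption, a longer checkpointing interval leaves larger gaps between persisted entries, so more batches fall after the last durable entry and are replayed.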



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049032417


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))

Review Comment:
   The problem is the test will fail:
   
   ListBuffer(0, 1, 2, 3, 4, 5, 6, 7, 8, 9) did not equal Array(Range.Inclusive(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
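
For context, a minimal sketch of why a range-based expectation does not match, assuming the expected value was written as `Array(0 to 9)` (the object name `RangeVsArray` is only illustrative). `Array(0 to 9)` builds a one-element array holding the `Range` itself, whereas `(0 to 9).toArray` expands the range into ten `Int` elements.

```scala
object RangeVsArray {
  // One element: the Range object itself, not the numbers it contains.
  val wrapped: Array[Range.Inclusive] = Array(0 to 9)

  // Ten elements: the range expanded into individual Ints.
  val expanded: Array[Int] = (0 to 9).toArray

  def main(args: Array[String]): Unit = {
    println(wrapped.length)  // 1
    println(expanded.length) // 10
    println(expanded.sameElements(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))) // true
  }
}
```

So a shorter expectation would probably need `(0 to 9).toArray` rather than `Array(0 to 9)`.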
   





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049077958


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batch 2, 3, 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was enabled in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 0 through 6 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 of the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 of the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all values 0 through 39 should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 of the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for batches 11 through 20 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 of the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // process batches 0 through 5; log files are deleted below to create gaps in both logs
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 to 2; the deletions below leave offsets 0, 2 and commit 0
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // need to set the processing time to something so the manual clock will work

Review Comment:
   will remove



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049119941


##########
connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala:
##########
@@ -195,6 +200,102 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
     true
   }
 
+  /**
+   * Test async progress tracking capability with Kafka source and sink
+   */
+  test("async progress tracking") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 5)
+
+    val dataSent = new ListBuffer[String]()
+    testUtils.sendMessages(inputTopic, (0 until 15).map { case x =>
+      val m = s"foo-$x"
+      dataSent += m
+      m
+    }.toArray, Some(0))
+
+    val outputTopic = newTopic()
+    testUtils.createTopic(outputTopic, partitions = 5)
+
+    withTempDir { dir =>
+      val reader = spark
+        .readStream
+        .format("kafka")
+        .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+        .option("kafka.metadata.max.age.ms", "1")
+        .option("maxOffsetsPerTrigger", 5)
+        .option("subscribe", inputTopic)
+        .option("startingOffsets", "earliest")
+        .load()
+
+      def startQuery(): StreamingQuery = {
+        reader.writeStream
+          .format("kafka")
+          .option("checkpointLocation", dir.getCanonicalPath)
+          .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+          .option("kafka.max.block.ms", "5000")
+          .option("topic", outputTopic)
+          .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+          .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+          .queryName("kafkaStream")
+          .start()
+      }
+
+      def readResults(): ListBuffer[String] = {

Review Comment:
   The output would be the same, but both the code and the actual execution would be much simpler with a batch query. See the code below for the batch-query version:
   
   ```
   // assumes the suite's implicits (e.g. `import testImplicits._`) are in scope for `.as[String]`
   spark.read
             .format("kafka")
             .option("kafka.bootstrap.servers", testUtils.brokerAddress)
             .option("startingOffsets", "earliest")
             .option("subscribe", outputTopic)
             .load()
             .selectExpr("CAST(value AS STRING)")
             .as[String]
             .collect()
   ```
   
   The entire code in the method can be replaced with this query. I haven't tried it, but the code that actually runs shouldn't be much different.





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049083624


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"

Review Comment:
   That can happen in the current MicroBatchExecution, since multiple queries compete for a specific path in external storage. A batch that we add only in memory won't trigger this; we don't even share the in-memory cache for the commit log across streaming queries in the driver.
   
   But I don't mind leaving this as it is. We can just consider it as respecting the contract of the existing interface.
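   
   To illustrate the distinction, here is a minimal sketch, not the actual Spark classes. `FileBackedLog` and `InMemoryLog` are hypothetical stand-ins: the first assumes one file per batch id under a shared directory, the second assumes a cache owned by a single query. Only the file-backed variant can observe a second writer for the same batch.
   
   ```scala
   import java.nio.file.{FileAlreadyExistsException, Files, Path}
   import java.util.concurrent.ConcurrentHashMap

   // Hypothetical stand-in for a file-backed log: one file per batch id in a shared directory.
   class FileBackedLog(dir: Path) {
     // Returns false when some writer has already created the file for this batch.
     def add(batchId: Long): Boolean =
       try { Files.createFile(dir.resolve(batchId.toString)); true }
       catch { case _: FileAlreadyExistsException => false }
   }

   // Hypothetical stand-in for the in-memory path: each query owns its own cache.
   class InMemoryLog {
     private val cache = new ConcurrentHashMap[Long, String]()
     // Returns false only if this same instance already holds the batch.
     def addInMemory(batchId: Long): Boolean = cache.putIfAbsent(batchId, "commit") == null
   }

   object ConcurrentUpdateSketch {
     def main(args: Array[String]): Unit = {
       val dir = Files.createTempDirectory("commits")
       val queryA = new FileBackedLog(dir)
       val queryB = new FileBackedLog(dir)
       println(queryA.add(0)) // first writer creates the file
       println(queryB.add(0)) // second writer finds the file and reports a conflict

       val memA = new InMemoryLog
       val memB = new InMemoryLog
       println(memA.addInMemory(0))
       println(memB.addInMemory(0)) // separate caches, so no conflict is visible
     }
   }
   ```
   
   Under these assumptions the in-memory check can only fail if the same query adds the same batch twice, which is why the exception there mostly serves to keep the existing contract.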





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048960580


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"
+          )
+        }
+      }
+      offsetLog.removeAsyncOffsetWrite(currentBatchId)
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  // need to look at the number of files on disk
+  override def purge(threshold: Long): Unit = {
+    while (offsetLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      offsetLog.writtenToDurableStorage.poll()
+    }
+    offsetLog.purge(offsetLog.writtenToDurableStorage.peek())
+
+    while (commitLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      commitLog.writtenToDurableStorage.poll()
+    }
+    commitLog.purge(commitLog.writtenToDurableStorage.peek())
+  }
+
+  override def cleanup(): Unit = {
+    super.cleanup()
+
+    ThreadUtils.shutdown(asyncWritesExecutorService)
+    logInfo(s"Async progress tracking executor pool for query ${prettyIdString} has been shutdown")
+  }
+
+  // used for testing
+  def areWritesPendingOrInProgress(): Boolean = {
+    asyncWritesExecutorService.getQueue.size() > 0 || asyncWritesExecutorService.getActiveCount > 0
+  }
+
+  private def validateAndGetTrigger(): TriggerExecutor = {
+    // validate that the pipeline is using a supported sink
+    if (!extraOptions
+      .get(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK
+      )
+      .getOrElse("false")
+      .toBoolean) {
+      try {
+        plan.sink.name() match {
+          case "noop-table" =>
+          case "console" =>
+          case "MemorySink" =>
+          case "KafkaTable" =>
+          case _ =>
+            throw new IllegalArgumentException(
+              s"Sink ${plan.sink.name()}" +
+                s" does not support async progress tracking"
+            )
+        }
+      } catch {
+        case e: IllegalStateException =>
+          // sink does not implement name() method
+          if (e.getMessage.equals("should not be called.")) {
+            throw new IllegalArgumentException(
+              s"Sink ${plan.sink}" +
+                s" does not support async progress tracking"
+            )
+          } else {
+            throw e
+          }
+      }
+    }
+
+    trigger match {
+      case t: ProcessingTimeTrigger => ProcessingTimeExecutor(t, triggerClock)
+      case OneTimeTrigger =>
+        throw new IllegalArgumentException(
+          "Async progress tracking cannot be used with Once trigger")
+      case AvailableNowTrigger =>
+        throw new IllegalArgumentException(
+          "Async progress tracking cannot be used with AvailableNow trigger"
+        )
+      case _ => throw new IllegalStateException(s"Unknown type of trigger: $trigger")
+    }
+  }
+
+  private def checkNotStatefulPipeline: Unit = {
+    if (isFirstBatch) {
+      lastExecution.executedPlan.collect {
+        case p if p.isInstanceOf[StateStoreWriter] =>
+          throw new IllegalArgumentException(
+            "Stateful streaming queries do not support async progress tracking at this moment."
+          )
+          isFirstBatch = false

Review Comment:
   good catch!





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048956305


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {

Review Comment:
   ok





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048947211


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)

Review Comment:
   ok





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048929676


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+    extends CommitLog(sparkSession, path) {
+
+  // A queue of batches written to storage.  Used to keep track when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Writes a new batch to the commit log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompletableFuture that contains the batch id.  The future is completed when
+   *         the async write of the batch is completed.  Future may also be completed exceptionally
+   *         to indicate some write error.
+   */
+  def addAsync(batchId: Long, metadata: CommitMetadata): CompletableFuture[Long] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+    val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+      serialize(metadata, output)
+    }.thenApply((ret: Boolean) => {
+      if (ret) {
+        batchId
+      } else {
+        throw new IllegalStateException(
+          s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"
+        )
+      }
+    })
+
+    future
+  }
+
+  /**
+   * Adds batch to commit log only in memory and not persisted to durable storage. This method is
+   * used when we don't want to persist the commit log entry for every micro batch
+   * to durable storage
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return true if operation is successful otherwise false.
+   */
+  def addInMemory(batchId: Long, metadata: CommitMetadata): Boolean = {
+    if (batchCache.containsKey(batchId)) {
+      false
+    } else {
+      batchCache.put(batchId, metadata)
+      true
+    }
+  }
+
+  /**
+   * Purge entries in the commit log up to thresholdBatchId.  This method is synchronized so that

Review Comment:
   The parent class will remove the entry from batchCache
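
For readers following the thread, here is a minimal sketch of the point above. It assumes a parent log whose `purge` drops cached entries below the threshold, so an async subclass only needs to call `super.purge`. `SketchLog`, `SketchAsyncLog`, `cachedIds`, and the `TrieMap` cache are hypothetical stand-ins for illustration, not the classes in this PR.

```scala
import scala.collection.concurrent.TrieMap

// Hypothetical stand-in for the parent metadata log (not the real Spark class).
class SketchLog {
  // In-memory cache of batch id -> metadata, playing the role of batchCache.
  protected val batchCache = TrieMap.empty[Long, String]

  def add(batchId: Long, metadata: String): Unit = batchCache.put(batchId, metadata)

  // The parent purge removes every cached entry strictly below the threshold.
  def purge(thresholdBatchId: Long): Unit = {
    batchCache.keys.toList
      .filter(_ < thresholdBatchId)
      .foreach(id => batchCache.remove(id))
  }

  def cachedIds: Set[Long] = batchCache.keys.toSet
}

// Hypothetical async subclass: it delegates cache cleanup to the parent.
class SketchAsyncLog extends SketchLog {
  override def purge(thresholdBatchId: Long): Unit = synchronized {
    super.purge(thresholdBatchId) // cache entries are removed here, by the parent
    // any async-specific bookkeeping would follow
  }
}
```

Under these assumptions, purging with threshold 3 after adding batches 0 to 4 leaves only batches 3 and 4 cached, without the subclass touching the cache directly.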





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1043131174


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"
+          )
+        }
+      }
+      offsetLog.removeAsyncOffsetWrite(currentBatchId)
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  // need to look at the number of files on disk
+  override def purge(threshold: Long): Unit = {
+    while (offsetLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      offsetLog.writtenToDurableStorage.poll()
+    }
+    offsetLog.purge(offsetLog.writtenToDurableStorage.peek())
+
+    while (commitLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      commitLog.writtenToDurableStorage.poll()
+    }
+    commitLog.purge(commitLog.writtenToDurableStorage.peek())
+  }
+
+  override def cleanup(): Unit = {
+    super.cleanup()
+
+    ThreadUtils.shutdown(asyncWritesExecutorService)
+    logInfo(s"Async progress tracking executor pool for query ${prettyIdString} has been shutdown")
+  }
+
+  // used for testing
+  def areWritesPendingOrInProgress(): Boolean = {
+    asyncWritesExecutorService.getQueue.size() > 0 || asyncWritesExecutorService.getActiveCount > 0
+  }
+
+  private def validateAndGetTrigger(): TriggerExecutor = {
+    // validate that the pipeline is using a supported sink
+    if (!extraOptions
+      .get(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK

Review Comment:
   Also, just to be clear, this option is for advanced usage and we won't document this explicitly, do I understand correctly?
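
To make the question concrete, here is a rough sketch of how an undocumented, underscore-prefixed option of this kind might be read from the query's extra options. The key string is copied from the diff. The `SketchOptions` object, the `overrideSinkCheck` helper, and the parsing rule (missing key or any value other than "true" means false) are assumptions for illustration, since the actual check is cut off in the quoted hunk.

```scala
// Hypothetical helper, not part of the PR.
object SketchOptions {
  // Key copied from the diff; the leading underscore suggests internal or test-only use.
  val OverrideSinkSupportCheck = "_asyncProgressTrackingOverrideSinkSupportCheck"

  // Assumed parsing rule: only an explicit "true" (case-insensitive) enables the override.
  def overrideSinkCheck(extraOptions: Map[String, String]): Boolean =
    extraOptions.get(OverrideSinkSupportCheck).exists(_.trim.toLowerCase == "true")
}
```

With a rule like this, a user who never sets the key keeps the sink validation, which would be consistent with leaving the option out of the user-facing documentation.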



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no op for async progress tracking since we only want to commit sources only
+    // after the offset WAL commit has be successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check if current batch there is a async write for the offset log is issued for this batch
+      // if so, we should do the same for commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {

Review Comment:
   nit: could we simply use `nonEmpty` rather than `!` + `isEmpty`?
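
   A minimal sketch of the suggestion, assuming `getAsyncOffsetWrite` returns an `Option` (its return type is not visible in this hunk). `pendingWrites` and `lookup` are hypothetical stand-ins for the real offset log, not names from the PR.

   ```scala
   object NonEmptySketch {
     // Hypothetical stand-in for offsetLog.getAsyncOffsetWrite; assumed to return an Option.
     private val pendingWrites: Map[Long, String] = Map(3L -> "offset-write-3")

     def lookup(batchId: Long): Option[String] = pendingWrites.get(batchId)

     def main(args: Array[String]): Unit = {
       val batchId = 3L
       // Before: negated emptiness check, as in the hunk above.
       val before = !lookup(batchId).isEmpty
       // After: the same condition stated directly.
       val after = lookup(batchId).nonEmpty
       assert(before == after)
     }
   }
   ```

   If the value really is an `Option`, `isDefined` would read equally well.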



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no op for async progress tracking since we only want to commit sources only
+    // after the offset WAL commit has be successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check if current batch there is a async write for the offset log is issued for this batch
+      // if so, we should do the same for commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"

Review Comment:
   Let's mention "commit" log specifically.
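
   Something along these lines might do; the helper name is hypothetical and only the wording of the message matters here.

   ```scala
   object CommitLogMessageSketch {
     // Hypothetical helper: names the commit log explicitly so the message
     // cannot be confused with a failure on the offset log.
     def concurrentCommitLogUpdate(batchId: Long): String =
       s"Concurrent update to the commit log. Multiple streaming jobs detected for $batchId"

     def main(args: Array[String]): Unit = {
       println(concurrentCommitLogUpdate(12L))
     }
   }
   ```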



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no op for async progress tracking since we only want to commit sources only
+    // after the offset WAL commit has be successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)

Review Comment:
   Let's add the batch ID to this log message as well.
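
   A rough sketch of what I mean, using a plain `CompletableFuture` instead of the real offset log. `offsetWriteErrorMessage` and the `println` standing in for `logError` are assumptions for illustration only.

   ```scala
   import java.util.concurrent.CompletableFuture

   object BatchIdInLogSketch {
     // Hypothetical helper: includes the batch ID so a failed write can be traced to its batch.
     def offsetWriteErrorMessage(batchId: Long): String =
       s"Encountered error while performing async offset write for batch $batchId"

     def main(args: Array[String]): Unit = {
       val currentBatchId = 7L // stand-in for the field in MicroBatchExecution
       val write = new CompletableFuture[String]()
       val handled = write.exceptionally((th: Throwable) => {
         // Stand-in for logError(message, th).
         println(offsetWriteErrorMessage(currentBatchId) + ": " + th.getMessage)
         "failed"
       })
       write.completeExceptionally(new RuntimeException("simulated failure"))
       assert(handled.get() == "failed")
     }
   }
   ```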



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no op for async progress tracking since we only want to commit sources only
+    // after the offset WAL commit has be successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+

Review Comment:
   nit: this empty line is probably unnecessary.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful

Review Comment:
   There seems to be a terminology issue.
   
   https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html
   
   The guide doc never mentions "pipeline". The term there is simply "query", or more specifically "streaming query".



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true

Review Comment:
   The point here is not whether this is the first batch of the current run, but whether we have already checked that the query is stateless. Shall we rename this accordingly?
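
   For illustration only, a sketch of how the flag could read after a rename. The names `statelessCheckDone` and `checkStatelessOnce` are hypothetical, and rejecting stateful queries via `require` is an assumption of this sketch rather than what the PR does.

   ```scala
   object StatelessCheckFlagSketch {
     // Hypothetical name: says what the flag tracks, not which batch we are on.
     private var statelessCheckDone: Boolean = false
     private var checks: Int = 0

     def checksRun: Int = checks

     // Runs the (assumed) stateless check at most once.
     def checkStatelessOnce(queryIsStateful: => Boolean): Unit = {
       if (!statelessCheckDone) {
         checks += 1
         require(!queryIsStateful, "stateful streaming query rejected (assumption of this sketch)")
         statelessCheckDone = true
       }
     }

     def main(args: Array[String]): Unit = {
       checkStatelessOnce(queryIsStateful = false)
       checkStatelessOnce(queryIsStateful = false)
       assert(checksRun == 1)
     }
   }
   ```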



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no op for async progress tracking since we only want to commit sources only
+    // after the offset WAL commit has be successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check if current batch there is a async write for the offset log is issued for this batch
+      // if so, we should do the same for commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)

Review Comment:
   Could we make the error log message more specific? I'd expect it to state that the error happened while writing to the commit log, and for which batch ID.
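
   As an illustration of the kind of message meant here, a minimal sketch follows. `CommitLogErrorSketch`, `commitLogWriteError`, and `withErrorLogging` are hypothetical helpers, and the `log` parameter stands in for `logError`; only `CompletableFuture.exceptionally` is a real API:

   ```scala
   import java.util.concurrent.CompletableFuture

   // Sketch only: the object and method names are hypothetical.
   object CommitLogErrorSketch {
     // Names both the log being written and the batch that failed.
     def commitLogWriteError(batchId: Long): String =
       s"Encountered error while performing async write to the commit log for batch $batchId"

     // Attaches an error handler to an async write; `log` stands in for logError.
     def withErrorLogging(
         write: CompletableFuture[Long],
         batchId: Long,
         log: (String, Throwable) => Unit): CompletableFuture[Long] = {
       write.exceptionally((th: Throwable) => {
         log(commitLogWriteError(batchId), th)
         -1L // placeholder result so the returned future completes normally
       })
     }
   }
   ```

   The placeholder return value is only there to keep the sketch self-contained; the PR's own handler also reports the error through `errorNotifier`.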



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no op for async progress tracking since we only want to commit sources only
+    // after the offset WAL commit has be successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"

Review Comment:
   It'd be nice to be more specific. If we only see the exception's message, nothing tells us where the batch ID is missing from. This is unlikely to be user-facing, but the message isn't friendly even to us.
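
   For example, the message could name the log that was searched. This is only a sketch: `MissingBatchMessageSketch` and `missingOffsetLogEntry` are hypothetical names, and the wording is a suggestion:

   ```scala
   // Sketch only: the object and method names are hypothetical.
   object MissingBatchMessageSketch {
     // States which log was consulted and why the entry was expected there.
     def missingOffsetLogEntry(batchId: Long): String =
       s"Offset log does not contain an entry for batch $batchId, which was " +
         "expected to be the last batch persisted to durable storage"
   }

   // Usage at the failure site would then look like:
   //   throw new IllegalStateException(MissingBatchMessageSketch.missingOffsetLogEntry(id))
   ```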



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {

Review Comment:
   nit: If the singleton object is purely a helper, it may be better to place it after the class definition, so that the more important thing comes first.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+

Review Comment:
   nit: it may be better to import `AsyncProgressTrackingMicroBatchExecution._` and use its members directly.
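
   A small sketch of the pattern, using a hypothetical `IntervalOptionSketch` class and companion in place of the real ones. The option key and the "1000" default are copied from the diff; everything else is illustrative:

   ```scala
   // Sketch only: `IntervalOptionSketch` is a hypothetical stand-in for the real class.
   class IntervalOptionSketch(extraOptions: Map[String, String]) {
     // Importing the companion's members lets the class body drop the
     // `IntervalOptionSketch.` prefix on every reference.
     import IntervalOptionSketch._

     val checkpointingIntervalMs: Long = getCheckpointingIntervalMs(extraOptions)
   }

   object IntervalOptionSketch {
     val CHECKPOINTING_INTERVAL_MS = "asyncProgressTrackingCheckpointIntervalMs"

     // Inside the object itself, members are already in scope unqualified.
     def getCheckpointingIntervalMs(extraOptions: Map[String, String]): Long =
       extraOptions.getOrElse(CHECKPOINTING_INTERVAL_MS, "1000").toLong
   }
   ```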



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {

Review Comment:
   nit: 2 spaces here. Please refer to the "class" part of the indentation section.
   
   https://github.com/databricks/scala-style-guide#indent
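
   For reference, a sketch of the layout that guide describes for class declarations: 4 spaces for constructor parameters and 2 spaces for `extends`. `BaseExecutionSketch` and `IndentationSketch` are hypothetical names:

   ```scala
   // Sketch only: both type names are hypothetical.
   trait BaseExecutionSketch

   class IndentationSketch(
       name: String, // 4-space indent for constructor parameters
       extraOptions: Map[String, String])
     extends BaseExecutionSketch { // 2-space indent for `extends`

     def describe: String = s"$name with ${extraOptions.size} option(s)"
   }
   ```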



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+    extends CommitLog(sparkSession, path) {
+
+  // A queue of batches written to storage.  Used to keep track when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Writes a new batch to the commit log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompeletableFuture that contains the batch id.  The future is completed when
+   *         the async write of the batch is completed.  Future may also be completed exceptionally
+   *         to indicate some write error.
+   */
+  def addAsync(batchId: Long, metadata: CommitMetadata): CompletableFuture[Long] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+    val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+      serialize(metadata, output)
+    }.thenApply((ret: Boolean) => {
+      if (ret) {
+        batchId
+      } else {
+        throw new IllegalStateException(
+          s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"

Review Comment:
   commit log



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,

Review Comment:
   nit: the `AsyncProgressTrackingMicroBatchExecution.` qualifier looks unnecessary here.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"

Review Comment:
   Please make sure this default value is documented as well when we document this option.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,133 @@
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+    extends CommitLog(sparkSession, path) {
+
+  // A queue of batches written to storage.  Used to keep track when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Writes a new batch to the commit log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompeletableFuture that contains the batch id.  The future is completed when
+   *         the async write of the batch is completed.  Future may also be completed exceptionally
+   *         to indicate some write error.
+   */
+  def addAsync(batchId: Long, metadata: CommitMetadata): CompletableFuture[Long] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+    val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+      serialize(metadata, output)
+    }.thenApply((ret: Boolean) => {
+      if (ret) {
+        batchId
+      } else {
+        throw new IllegalStateException(
+          s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"
+        )
+      }
+    })
+
+    future
+  }
+
+  /**
+   * Adds batch to commit log only in memory and not persisted to durable storage. This method is
+   * used when we don't want to persist the commit log entry for every micro batch
+   * to durable storage
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return true if operation is successful otherwise false.
+   */
+  def addInMemory(batchId: Long, metadata: CommitMetadata): Boolean = {
+    if (batchCache.containsKey(batchId)) {
+      false
+    } else {
+      batchCache.put(batchId, metadata)
+      true
+    }
+  }
+
+  /**
+   * Purge entries in the commit log up to thresholdBatchId.  This method is synchronized so that

Review Comment:
   same here



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncOffsetSeqLog.scala:
##########
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.{Clock, SystemClock}
+
+/**
+ * Used to write entries to the offset log asynchronously
+ */
+class AsyncOffsetSeqLog(
+    sparkSession: SparkSession,
+    path: String,
+    executorService: ThreadPoolExecutor,
+    offsetCommitIntervalMs: Long,
+    clock: Clock = new SystemClock())
+    extends OffsetSeqLog(sparkSession, path) {
+
+  // the cache needs to be enabled because we may not be persisting every entry to durable storage
+  // entries not persisted to durable storage will just be stored in memory for faster lookups
+  assert(metadataCacheEnabled == true)
+
+  // A map of the current pending offset writes. Key -> batch Id, Value -> CompletableFuture
+  // Used to determine if a commit log entry for this batch also needs to be persisted to storage
+  private val pendingOffsetWrites = new ConcurrentHashMap[Long, CompletableFuture[Long]]()
+
+  // Keeps track of the last time a commit was issued. Used for issuing commits to storage at
+  // the configured intervals
+  private val lastCommitIssuedTimestampMs: AtomicLong = new AtomicLong(-1)
+
+  // A queue of batches written to storage.  Used to keep track of when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Gets the async offset write for a batch id.  Used to check whether a corresponding commit
+   * log entry also needs to be written to durable storage.
+   * @param batchId id of the batch to look up
+   * @return an option indicating whether an async offset write was issued for the batch
+   */
+  def getAsyncOffsetWrite(batchId: Long): Option[CompletableFuture[Long]] = {
+    Option(pendingOffsetWrites.get(batchId))
+  }
+
+  /**
+   * Remove the async offset write when we don't need to keep track of it anymore
+   * @param batchId
+   */
+  def removeAsyncOffsetWrite(batchId: Long): Unit = {
+    pendingOffsetWrites.remove(batchId)
+  }
+
+  /**
+   * Writes a new batch to the offset log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompletableFuture that contains the batch id and a flag indicating whether the
+   *         batch was persisted to durable storage.  The future is completed when the async
+   *         write of the batch is completed.  The future may also be completed exceptionally
+   *         to indicate a write error.
+   */
+  def addAsync(batchId: Long, metadata: OffsetSeq): CompletableFuture[(Long, Boolean)] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+
+    def issueAsyncWrite(batchId: Long): CompletableFuture[Long] = {
+      lastCommitIssuedTimestampMs.set(clock.getTimeMillis())
+      val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+        serialize(metadata, output)
+      }.thenApply((ret: Boolean) => {
+        if (ret) {
+          batchId
+        } else {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"

Review Comment:
   Shall we be more specific about which log this refers to? The offset log or the commit log?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncOffsetSeqLog.scala:
##########
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.{Clock, SystemClock}
+
+/**
+ * Used to write entries to the offset log asynchronously
+ */
+class AsyncOffsetSeqLog(
+    sparkSession: SparkSession,
+    path: String,
+    executorService: ThreadPoolExecutor,
+    offsetCommitIntervalMs: Long,
+    clock: Clock = new SystemClock())
+    extends OffsetSeqLog(sparkSession, path) {
+
+  // the cache needs to be enabled because we may not be persisting every entry to durable storage
+  // entries not persisted to durable storage will just be stored in memory for faster lookups
+  assert(metadataCacheEnabled == true)
+
+  // A map of the current pending offset writes. Key -> batch Id, Value -> CompletableFuture
+  // Used to determine if a commit log entry for this batch also needs to be persisted to storage
+  private val pendingOffsetWrites = new ConcurrentHashMap[Long, CompletableFuture[Long]]()
+
+  // Keeps track of the last time a commit was issued. Used for issuing commits to storage at
+  // the configured intervals
+  private val lastCommitIssuedTimestampMs: AtomicLong = new AtomicLong(-1)
+
+  // A queue of batches written to storage.  Used to keep track of when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Gets the async offset write for a batch id.  Used to check whether a corresponding commit
+   * log entry also needs to be written to durable storage.
+   * @param batchId id of the batch to look up
+   * @return an option indicating whether an async offset write was issued for the batch
+   */
+  def getAsyncOffsetWrite(batchId: Long): Option[CompletableFuture[Long]] = {
+    Option(pendingOffsetWrites.get(batchId))
+  }
+
+  /**
+   * Remove the async offset write when we don't need to keep track of it anymore
+   * @param batchId
+   */
+  def removeAsyncOffsetWrite(batchId: Long): Unit = {
+    pendingOffsetWrites.remove(batchId)
+  }
+
+  /**
+   * Writes a new batch to the offset log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompletableFuture that contains the batch id and a flag indicating whether the
+   *         batch was persisted to durable storage.  The future is completed when the async
+   *         write of the batch is completed.  The future may also be completed exceptionally
+   *         to indicate a write error.
+   */
+  def addAsync(batchId: Long, metadata: OffsetSeq): CompletableFuture[(Long, Boolean)] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+
+    def issueAsyncWrite(batchId: Long): CompletableFuture[Long] = {
+      lastCommitIssuedTimestampMs.set(clock.getTimeMillis())
+      val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+        serialize(metadata, output)
+      }.thenApply((ret: Boolean) => {
+        if (ret) {
+          batchId
+        } else {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"
+          )
+        }
+      })
+      pendingOffsetWrites.put(batchId, future)
+      future
+    }
+
+    val lastIssuedTs = lastCommitIssuedTimestampMs.get()
+    val future: CompletableFuture[(Long, Boolean)] = {
+      if (offsetCommitIntervalMs > 0) {
+        if ((lastIssuedTs == -1) // haven't started any commits yet
+          || (lastIssuedTs + offsetCommitIntervalMs) <= clock.getTimeMillis()) {
+          issueAsyncWrite(batchId).thenApply((batchId: Long) => {
+            (batchId, true)
+          })
+        } else {
+          // just return completed future because we are not persisting this offset
+          CompletableFuture.completedFuture((batchId, false))
+        }
+      } else {
+        // offset commit interval is not enabled
+        issueAsyncWrite(batchId).thenApply((batchId: Long) => {
+          (batchId, true)
+        })
+      }
+    }
+
+    batchCache.put(batchId, metadata)
+    future
+  }
+
+  /**
+   * Adds new batch asynchronously
+   * @param batchId id of batch to write
+   * @param fn serialization function
+   * @return CompletableFuture that contains a boolean to
+   *         indicate whether the write was successful or not.
+   *         Future can also be completed exceptionally to indicate write errors.
+   */
+  private def addNewBatchByStreamAsync(batchId: Long)(
+      fn: OutputStream => Unit): CompletableFuture[Boolean] = {
+    val future = new CompletableFuture[Boolean]()
+    val batchMetadataFile = batchIdToPath(batchId)
+
+    if (batchCache.containsKey(batchId)) {
+      future.complete(false)
+      future
+    } else {
+      executorService.submit(new Runnable {
+        override def run(): Unit = {
+          try {
+            if (fileManager.exists(batchMetadataFile)) {
+              future.complete(false)
+            } else {
+              val start = System.currentTimeMillis()
+              write(
+                batchMetadataFile,
+                fn
+              )
+              logDebug(
+                s"Offset commit for batch ${batchId} took" +
+                s" ${System.currentTimeMillis() - start} ms to be persisted to durable storage"
+              )
+              writtenToDurableStorage.add(batchId)
+              future.complete(true)
+            }
+          } catch {
+            case e: Throwable =>
+              logError(s"Encountered error while writing batch ${batchId} to offset log", e)
+              future.completeExceptionally(e)
+          }
+        }
+      })
+      future
+    }
+  }
+
+  /**
+   * Purge entries in the offset log up to thresholdBatchId.  This method is synchronized so that

Review Comment:
   Is there synchronization between the two? I don't see any synchronization here. Am I missing something? 
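
   For what it's worth, if the intent is to rely on the single-threaded executor that this PR uses for async writes, one option might be to submit the purge to that same executor so that purges and writes never interleave. A minimal sketch under that assumption follows; `SerializedLog`, `writeAsync` and `purgeAsync` are hypothetical names for illustration, not code from this PR.

   ```scala
   import java.util.concurrent.{ConcurrentLinkedDeque, Executors, Future, TimeUnit}

   // Hypothetical log that routes both writes and purges through one worker thread,
   // so tasks run one at a time in submission order.
   class SerializedLog {
     private val executor = Executors.newSingleThreadExecutor()

     // Batch ids that have been "written", oldest first.
     val written = new ConcurrentLinkedDeque[Long]()

     // Records the batch id on the worker thread.
     def writeAsync(batchId: Long): Future[_] =
       executor.submit(new Runnable {
         override def run(): Unit = written.add(batchId)
       })

     // Drops every recorded batch id below the threshold, also on the worker thread.
     def purgeAsync(thresholdBatchId: Long): Future[_] =
       executor.submit(new Runnable {
         override def run(): Unit = {
           while (!written.isEmpty && written.peekFirst() < thresholdBatchId) {
             written.pollFirst()
           }
         }
       })

     // Stops accepting tasks and waits briefly for queued ones to finish.
     def shutdown(): Unit = {
       executor.shutdown()
       executor.awaitTermination(5, TimeUnit.SECONDS)
     }
   }
   ```

   Because the purge task is queued behind the writes submitted before it, no extra locking is needed in this sketch. Whether that matches the intended design here is a separate question.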



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncOffsetSeqLog.scala:
##########
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.util.{Clock, SystemClock}
+
+/**
+ * Used to write entries to the offset log asynchronously
+ */
+class AsyncOffsetSeqLog(
+    sparkSession: SparkSession,
+    path: String,
+    executorService: ThreadPoolExecutor,
+    offsetCommitIntervalMs: Long,
+    clock: Clock = new SystemClock())
+    extends OffsetSeqLog(sparkSession, path) {

Review Comment:
   nit: 2 spaces



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+    extends CommitLog(sparkSession, path) {

Review Comment:
   nit: 2 spaces



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for this batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"

Review Comment:
   By the way, can this even happen? If multiple queries based on the same checkpoint run concurrently, even in the same driver, each of them would have its own AsyncCommitLog instance.
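
   To illustrate the concern, here is a minimal sketch with a hypothetical `InMemoryLog` class (not the PR's code, and with a plain `String` standing in for the metadata). It assumes the in-memory cache is a per-instance map, as `batchCache` appears to be.

   ```scala
   import java.util.concurrent.ConcurrentHashMap

   // Hypothetical stand-in for a commit log with a per-instance in-memory cache.
   class InMemoryLog {
     private val batchCache = new ConcurrentHashMap[Long, String]()

     // Returns false only if this same instance has already cached the batch id.
     def addInMemory(batchId: Long, metadata: String): Boolean =
       batchCache.putIfAbsent(batchId, metadata) == null
   }
   ```

   With two instances, both calls to `addInMemory` for the same batch id return true, so a check like this would only catch a duplicate add within a single instance.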



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+    extends CommitLog(sparkSession, path) {
+

Review Comment:
   Does this class also require metadataCacheEnabled == true? If so, let's add a require here as well.
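
   Something along these lines, as a rough sketch. `CachedLog` and `AsyncLog` are hypothetical stand-ins for the real classes, and the error message is only a suggestion.

   ```scala
   // Hypothetical base class standing in for a metadata log with an optional cache.
   class CachedLog(val metadataCacheEnabled: Boolean)

   // Hypothetical async subclass that fails fast at construction time
   // when the cache it depends on is disabled.
   class AsyncLog(cacheEnabled: Boolean) extends CachedLog(cacheEnabled) {
     require(metadataCacheEnabled, "async log requires the metadata cache to be enabled")
   }
   ```

   Unlike `assert`, `require` throws an IllegalArgumentException and is not elided by compiler flags, which seems like the better fit for validating a precondition.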



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+    extends CommitLog(sparkSession, path) {
+
+  // A queue of batches written to storage.  Used to keep track of when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Writes a new batch to the commit log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompletableFuture that contains the batch id.  The future is completed when
+   *         the async write of the batch is completed.  Future may also be completed exceptionally
+   *         to indicate some write error.
+   */
+  def addAsync(batchId: Long, metadata: CommitMetadata): CompletableFuture[Long] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+    val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+      serialize(metadata, output)
+    }.thenApply((ret: Boolean) => {
+      if (ret) {
+        batchId
+      } else {
+        throw new IllegalStateException(
+          s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"
+        )
+      }
+    })
+

Review Comment:
   Is the missing batchCache update here an oversight, or is it intentional?





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052639001


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1765 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.File
+import java.util.concurrent.{CountDownLatch, Semaphore, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.TestUtils
+import org.apache.spark.sql._
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.Utils
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+  extends StreamTest
+    with BeforeAndAfter

Review Comment:
   will fix





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052689380


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure there are no pending async offset writes left
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure there are no pending async offset writes left
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all values should have been processed in order, without duplicates
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by creating a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
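+    // only batches 0 and 3 should be logged: batch 0 is the first batch, and batch 3 runs
+    // 1000 ms (100 + 100 + 800) after it, which reaches the checkpointing interval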
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 through 5, then delete log files below to create
+    // multiple gaps in the offset and commit logs
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),

Review Comment:
   will add a comment
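
   For context on what such a comment could say: at this point the offset log holds batches 0, 2, 5 and the commit log holds batches 0, 2, so the restart is expected to re-run everything after the latest commit up to the latest offset entry, which is why the answer is 3, 4, 5. Below is a minimal sketch of that reasoning, assuming one record per batch as in this test. `ReplaySketch` and `replayedBatches` are hypothetical names for illustration only, not Spark APIs, and the replayed data may well be processed as a single batch; the sketch only models which original batch ids are covered.

```scala
// Hypothetical helper, not part of Spark: models which original batch ids
// are covered by the replay after a restart.
object ReplaySketch {

  /** Batch ids after the latest commit, up to and including the latest offset entry. */
  def replayedBatches(offsetIds: Seq[Int], commitIds: Seq[Int]): Seq[Int] = {
    if (offsetIds.isEmpty) {
      // Nothing was planned, so there is nothing to replay.
      Seq.empty
    } else {
      // With no commit log entries, no batch is known to be complete,
      // so the replay starts from batch 0.
      val firstUncommitted = if (commitIds.isEmpty) 0 else commitIds.max + 1
      firstUncommitted to offsetIds.max
    }
  }
}
```

   With the logs from this test, `ReplaySketch.replayedBatches(Seq(0, 2, 5), Seq(0, 2))` covers batches 3, 4, 5. With the logs from the earlier "first offset is not zero" test (offsets 2, no commits) it covers 0, 1, 2, matching the "should replay from the beginning" check there.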





[GitHub] [spark] HeartSaVioR commented on pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on PR #38517:
URL: https://github.com/apache/spark/pull/38517#issuecomment-1358900740

   I'll give +1 once the builds pass rather than waiting for all minor/nit comments to be addressed. We can deal with them in a follow-up PR.




[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052648054


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {

Review Comment:
   ok I will add a timeout
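
   One place a timeout could go, assuming it is meant for the `sem.acquire()` call inside this loop (the thread does not say which call is meant): `java.util.concurrent.Semaphore.tryAcquire` takes a timeout and returns `false` when no permit arrives in time, so the test can fail fast instead of hanging. A minimal sketch follows; `AcquireWithTimeout` and `acquireOrFail` are hypothetical names, not part of the PR.

```scala
import java.util.concurrent.{Semaphore, TimeUnit, TimeoutException}

// Hypothetical helper, not part of the PR: bounded wait on a semaphore.
object AcquireWithTimeout {

  /** Acquires one permit, or throws if none becomes available within timeoutMs. */
  def acquireOrFail(sem: Semaphore, timeoutMs: Long): Unit = {
    // tryAcquire returns false if the timeout elapses before a permit is released
    val acquired = sem.tryAcquire(timeoutMs, TimeUnit.MILLISECONDS)
    if (!acquired) {
      throw new TimeoutException(s"no permit released within $timeoutMs ms")
    }
  }
}
```

   In the loop this would replace `sem.acquire()` with something like `AcquireWithTimeout.acquireOrFail(sem, streamingTimeout.toMillis)`, reusing the `streamingTimeout` that the suite already passes to `countDownLatch.await`.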





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052728943


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // data should contain every value exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // no additional commit log entries should be logged since async offset commit is on
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // the new data is processed as batch 21, so one more offset log entry should be added
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error by writing a file for the next offset i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)

Review Comment:
   OK.
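
   The listing pipeline (list files, drop hidden ones, parse names as batch ids, sort) is repeated for every assertion in this suite. A minimal sketch of how it could be factored out, using only `java.io.File`; `CheckpointListing` and `listBatchIds` are hypothetical names that do not appear in the PR:

```scala
import java.io.File
import java.nio.file.Files

object CheckpointListing {
  // Hypothetical helper (not part of the PR): returns the sorted batch ids of the
  // regular, non-hidden files directly under `dir`. File names are assumed to be
  // plain integers, as they are for the offset and commit logs in these tests.
  def listBatchIds(dir: String): List[Int] = {
    val d = new File(dir)
    if (d.exists && d.isDirectory) {
      d.listFiles.toList
        .filter(f => f.isFile && !f.isHidden)
        .map(_.getName.toInt)
        .sorted
    } else {
      // A missing directory is treated as an empty log.
      List.empty[Int]
    }
  }
}
```

   With something like this, an assertion could read `listBatchIds(checkpointLocation + "/commits") should equal((0 to 20).toList)`, assuming the matcher compares two lists.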



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052927996


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -78,7 +72,9 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
   test("async WAL commits happy path") {
     val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
 
-    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+//    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)

Review Comment:
   nit: remove comment



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -283,3 +264,23 @@ class AsyncProgressTrackingMicroBatchExecution(
     }
   }
 }
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+                                                               extraOptions: Map[String, String]): Long = {

Review Comment:
   nit: indentation is off
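
   For readers following along, a minimal sketch of what this helper does with the options map, written outside Spark. The key string and the "1000" default are taken from the diff; `IntervalOptionSketch` and `checkpointIntervalMs` are hypothetical names used only for illustration:

```scala
object IntervalOptionSketch {
  // Same string as ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS in the diff.
  val IntervalKey = "asyncProgressTrackingCheckpointIntervalMs"

  // A missing key falls back to "1000"; a non-numeric value makes toLong
  // throw NumberFormatException rather than silently using the default.
  def checkpointIntervalMs(extraOptions: Map[String, String]): Long =
    extraOptions.getOrElse(IntervalKey, "1000").toLong
}
```

   The tests in this PR pass 0 for this option so that every batch is written to the logs.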



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -157,8 +172,17 @@ class AsyncProgressTrackingMicroBatchExecution(
     watermarkTracker.updateWatermark(lastExecution.executedPlan)
     reportTimeTaken("commitOffsets") {
       // check whether an async write to the offset log was issued for the current batch;
-      // if so, we should do the same for commit log
-      if (offsetLog.getAsyncOffsetWrite(currentBatchId).nonEmpty) {
+      // if so, we should do the same for commit log.  However, if this is the first batch executed
+      // in this run we should always persis to the commit log.  There can be situations in which

Review Comment:
   nit: persist





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052964187


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -157,8 +172,17 @@ class AsyncProgressTrackingMicroBatchExecution(
     watermarkTracker.updateWatermark(lastExecution.executedPlan)
     reportTimeTaken("commitOffsets") {
       // check whether an async write to the offset log was issued for the current batch;
-      // if so, we should do the same for commit log
-      if (offsetLog.getAsyncOffsetWrite(currentBatchId).nonEmpty) {
+      // if so, we should do the same for commit log.  However, if this is the first batch executed
+      // in this run we should always persis to the commit log.  There can be situations in which

Review Comment:
   will fix
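
   A minimal sketch of the rule the updated comment describes, assuming it reduces to a simple disjunction. The function and parameter names are hypothetical, and the actual condition in the PR may differ:

```scala
object CommitLogDecisionSketch {
  // Hypothetical pure function: write to the commit log when an async offset log
  // write was issued for this batch, or when this is the first batch of the run.
  def shouldWriteCommitLog(
      asyncOffsetWriteIssued: Boolean,
      isFirstBatchInRun: Boolean): Boolean =
    asyncOffsetWriteIssued || isFirstBatchInRun
}
```

   Under this reading, the first batch after a restart is always recorded in the commit log, even if the checkpoint interval would otherwise skip it.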





[GitHub] [spark] HeartSaVioR commented on pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on PR #38517:
URL: https://github.com/apache/spark/pull/38517#issuecomment-1358951698

   https://github.com/jerrypeng/spark/actions/runs/3737519049 
   Above build passed for the last commit [23210ec](https://github.com/apache/spark/pull/38517/commits/23210ecea2bf9b39267613d6b4f356dd057cf890), and [03222c2](https://github.com/apache/spark/pull/38517/commits/03222c21f2f9d0829a0baf326d9a12626c3547d4) only addressed minors.
   
   Thanks! Merging to master.




[GitHub] [spark] HeartSaVioR commented on pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on PR #38517:
URL: https://github.com/apache/spark/pull/38517#issuecomment-1336801690

   Also cc. @mridulm since he reviewed the design doc in detail.




[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052590390


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} ms due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())

Review Comment:
   Are we good here?
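
   For context on the executor set up earlier in this hunk: a minimal sketch of the same backpressure idea using only `java.util.concurrent`. It assumes the Spark helper builds a single-threaded `ThreadPoolExecutor` over a bounded queue, which is not shown in this diff; `BackpressureExecutorSketch`, `newExecutor` and `runInOrder` are hypothetical names:

```scala
import java.util.concurrent._
import scala.collection.mutable.ListBuffer

object BackpressureExecutorSketch {
  // When the bounded queue is full, block the submitting thread until there is
  // room instead of dropping the task or throwing.
  private val blockWhenFull = new RejectedExecutionHandler {
    override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
      if (!executor.isShutdown) {
        try {
          executor.getQueue.put(r)
        } catch {
          case e: InterruptedException =>
            Thread.currentThread.interrupt()
            throw new RejectedExecutionException("Producer interrupted", e)
        }
      }
    }
  }

  // One worker thread, so tasks run strictly in submission order.
  def newExecutor(queueCapacity: Int): ThreadPoolExecutor =
    new ThreadPoolExecutor(
      1, 1, 0L, TimeUnit.MILLISECONDS,
      new ArrayBlockingQueue[Runnable](queueCapacity),
      blockWhenFull)

  // Submits n tasks from a single producer and returns the order they ran in.
  def runInOrder(n: Int): List[Int] = {
    val executor = newExecutor(queueCapacity = 2)
    val seen = ListBuffer.empty[Int]
    (0 until n).foreach { i =>
      executor.execute(new Runnable {
        override def run(): Unit = {
          Thread.sleep(1) // slow the worker down so the queue fills up
          seen.synchronized { seen += i; () }
        }
      })
    }
    executor.shutdown()
    executor.awaitTermination(10, TimeUnit.SECONDS)
    seen.synchronized { seen.toList }
  }
}
```

   With a single producer, the blocking `put` keeps submission order intact, which is the property the offset and commit log writes rely on here.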





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052664424


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was enabled in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 0 through 6 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // every value added so far should have been received exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 10 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the commit log entry for batch 21 is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by creating a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/offsets")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to some value so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " not commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {

Review Comment:
   ok
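
   For readers following the quoted "interval commits and recovery" test, the expectation that only batches 0 and 3 (and, after the restart, 4 and 7) reach durable storage can be reproduced with a small standalone model. This is only a sketch, not the code in this PR: `persistedBatches` is a hypothetical helper, and it assumes that the first batch of a run is always persisted, that later batches are persisted once at least the checkpoint interval has elapsed since the last persisted batch, and that the manual clock reads 100, 200, 300 and 1100 ms when the four batches complete.

```scala
object IntervalCheckpointSketch {

  // Hypothetical model (not Spark's implementation): given (batchId, completionTimeMs)
  // pairs in order, return the ids of the batches whose progress would be persisted.
  def persistedBatches(completions: Seq[(Long, Long)], intervalMs: Long): Seq[Long] = {
    var lastPersistMs: Option[Long] = None
    completions.flatMap { case (batchId, nowMs) =>
      // Assumption: the first batch of a run is always due; later batches are due
      // only when the interval has fully elapsed since the last persisted batch.
      val due = lastPersistMs.forall(last => nowMs - last >= intervalMs)
      if (due) {
        lastPersistMs = Some(nowMs)
        Some(batchId)
      } else {
        None
      }
    }
  }

  def main(args: Array[String]): Unit = {
    // First run: clock advanced by 100, 100, 100 and 800 ms.
    val firstRun = Seq(0L -> 100L, 1L -> 200L, 2L -> 300L, 3L -> 1100L)
    // Restarted run: the same clock keeps advancing by 100, 100, 100 and 800 ms.
    val secondRun = Seq(4L -> 1200L, 5L -> 1300L, 6L -> 1400L, 7L -> 2200L)
    println(persistedBatches(firstRun, 1000L))
    println(persistedBatches(secondRun, 1000L))
  }
}
```

   Under these assumptions the model selects batches 0 and 3 for the first run and batches 4 and 7 for the restarted run, which matches the file names the test expects under /offsets and /commits.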



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052960918


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -68,11 +64,42 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
     }
   }
 
+  class MemoryStreamCapture[A: Encoder](
+                                         id: Int,

Review Comment:
   will fix





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1044075012


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {

Review Comment:
   Never mind, there are test cases.





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1052649920


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))

Review Comment:
   ok will change





[GitHub] [spark] HeartSaVioR commented on pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on PR #38517:
URL: https://github.com/apache/spark/pull/38517#issuecomment-1358899601

   (You can ignore outdated comments: I mistakenly looked at only the two most recent commits, so some of my comments may apply only to an old commit.)




[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049096226


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 0 through 6 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 input values should have been processed exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // process batches 0 through 5; gaps in the logs are created below by deleting files
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // process batches 0 through 2; gaps in the logs are created below by deleting files
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+  }
+
+  test("switching async progress tracking with interval commits on and off") {
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    var clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    inputData.commits.length should be(0)
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(3).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(4).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    /**
+     * Turn async progress tracking on
+     */
+    clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 6), // will get persisted since first batch on restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      AddData(inputData, 7),
+      AdvanceManualClock(100),
+      CheckNewAnswer(7),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(5).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 8),
+      AdvanceManualClock(100),
+      CheckNewAnswer(8),
+      AddData(inputData, 9),
+      AdvanceManualClock(800), //  persist offset
+      CheckNewAnswer(9),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9))
+
+    // simulate batch 9 missing from the commit log
+    new File(checkpointLocation + "/commits/9").delete()
+    new File(checkpointLocation + "/commits/.9.crc").delete()
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 10),
+      CheckNewAnswer(7, 8, 9, 10),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(9).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 11),
+      CheckNewAnswer(11),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(10).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0, 3, 4, 5, 6, 9, 10 and 11 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9, 10, 11))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5, 6, 9, 10, 11))
+  }
+
+  test("Fail on stateful pipelines") {
+    val rateStream = spark.readStream
+      .format("rate")
+      .option("numPartitions", 1)
+      .option("rowsPerSecond", 10)
+      .load()
+      .toDF()
+
+    val windowedStream = rateStream
+      .withWatermark("timestamp", "0 seconds")
+      .groupBy(window(column("timestamp"), "10 seconds"), column("value"))
+      .count()
+
+    val e = intercept[StreamingQueryException] {
+      val query = windowedStream.writeStream
+        .format("noop")
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+
+      query.processAllAvailable()
+    }
+    e.getMessage should include(
+      "Stateful streaming queries does not support async progress tracking at this moment."
+    )
+  }
+
+  test("Fail on pipelines using unsupported sinks") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("parquet")
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(
+          "checkpointLocation",
+          Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+        )
+        .start("/tmp")
+    }
+
+    e.getMessage should equal("Sink FileSink[/tmp] does not support async progress tracking")
+  }
+
+  test("with log purging") {
+
+    withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "false") {
+      withTempDir { checkpointLocation =>
+        val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+        val ds = inputData.toDS()
+
+        val clock = new StreamManualClock
+        testStream(
+          ds,
+          extraOptions = Map(
+            ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+            ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+          )
+        )(
+          // need to set the processing time to something so the manual clock will work
+          StartStream(
+            Trigger.ProcessingTime("1 millisecond"),
+            checkpointLocation = checkpointLocation.getCanonicalPath,
+            triggerClock = clock
+          ),
+          AddData(inputData, 0),
+          AdvanceManualClock(100),
+          CheckNewAnswer(0),
+          AddData(inputData, 1),
+          AdvanceManualClock(100),
+          CheckNewAnswer(1),
+          AddData(inputData, 2),
+          AdvanceManualClock(100),
+          CheckNewAnswer(2),
+          AddData(inputData, 3),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(3),
+          AddData(inputData, 4),
+          AdvanceManualClock(100),
+          CheckNewAnswer(4),
+          AddData(inputData, 5),
+          AdvanceManualClock(100),
+          CheckNewAnswer(5),
+          AddData(inputData, 6),
+          AdvanceManualClock(100),
+          CheckNewAnswer(6),
+          AddData(inputData, 7),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(7),
+          Execute { q =>
+            // wait for all async log writes to finish
+            waitPendingOffsetWrites(q)
+          },
+          // add a new row to make sure log purge has kicked in.
+          // There can be a race condition in which the commit log entry for the previous batch
+          // may or may not be written to disk yet before the log purge is called.
+          // Adding another batch here will make sure purge is called on the correct number of
+          // offset and commit log entries
+          AddData(inputData, 8),
+          AdvanceManualClock(100),
+          CheckNewAnswer(8),
+          Execute { q =>
+            waitPendingOffsetWrites(q)
+            getListOfFiles(checkpointLocation + "/offsets")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(3, 7))
+            getListOfFiles(checkpointLocation + "/commits")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(3, 7))
+            q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].writtenToDurableStorage.size() should be(2)
+            q.commitLog.asInstanceOf[AsyncCommitLog].writtenToDurableStorage.size() should be(2)
+          },
+          StopStream
+        )
+
+        /**
+         * restart
+         */
+        testStream(
+          ds,
+          extraOptions = Map(
+            ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+            ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+          )
+        )(
+          // need to set the processing time to something so the manual clock will work
+          StartStream(
+            Trigger.ProcessingTime("1 millisecond"),
+            checkpointLocation = checkpointLocation.getCanonicalPath,
+            triggerClock = clock
+          ),
+          AddData(inputData, 9),
+          AdvanceManualClock(100),
+          CheckNewAnswer(8, 9),
+          AddData(inputData, 10),
+          AdvanceManualClock(100),
+          CheckNewAnswer(10),
+          AddData(inputData, 11),
+          AdvanceManualClock(100),
+          CheckNewAnswer(11),
+          AddData(inputData, 12),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(12),
+          Execute { q =>
+            // wait for all async log writes to finish
+            waitPendingOffsetWrites(q)
+          },
+          // add a new row to make sure log purge has kicked in.
+          // There can be a race condition in which the commit log entry for the previous batch
+          // may or may not be written to disk yet before the log purge is called.
+          // Adding another batch here will make sure purge is called on the correct number of
+          // offset and commit log entries
+          AddData(inputData, 13),
+          AdvanceManualClock(100),
+          CheckNewAnswer(13),
+          Execute { q =>
+            waitPendingOffsetWrites(q)
+            getListOfFiles(checkpointLocation + "/offsets")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(8, 12))
+            getListOfFiles(checkpointLocation + "/commits")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(8, 12))
+
+            q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].writtenToDurableStorage.size() should be(2)
+            q.commitLog.asInstanceOf[AsyncCommitLog].writtenToDurableStorage.size() should be(2)
+          },
+          StopStream
+        )
+      }
+    }
+  }
+
+  test("with async log purging") {
+    withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") {
+      withTempDir { checkpointLocation =>
+        val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+        val ds = inputData.toDS()
+
+        val clock = new StreamManualClock
+        testStream(
+          ds,
+          extraOptions = Map(
+            ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+            ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+          )
+        )(
+          // need to set the processing time to something so the manual clock will work
+          StartStream(
+            Trigger.ProcessingTime("1 millisecond"),
+            checkpointLocation = checkpointLocation.getCanonicalPath,
+            triggerClock = clock
+          ),
+          AddData(inputData, 0),
+          AdvanceManualClock(100),
+          CheckNewAnswer(0),
+          AddData(inputData, 1),
+          AdvanceManualClock(100),
+          CheckNewAnswer(1),
+          AddData(inputData, 2),
+          AdvanceManualClock(100),
+          CheckNewAnswer(2),
+          AddData(inputData, 3),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          Execute { q =>
+            // wait for async log writes to complete
+            waitPendingOffsetWrites(q)
+            eventually(timeout(Span(5, Seconds))) {
+              getListOfFiles(checkpointLocation + "/offsets")
+                .filter(file => !file.isHidden)
+                .map(file => file.getName.toInt)
+                .sorted should equal(Array(0, 3))
+
+              getListOfFiles(checkpointLocation + "/commits")
+                .filter(file => !file.isHidden)
+                .map(file => file.getName.toInt)
+                .sorted should equal(Array(0, 3))
+            }
+          },
+          CheckNewAnswer(3),
+          AddData(inputData, 4),
+          AdvanceManualClock(100),
+          CheckNewAnswer(4),
+          AddData(inputData, 5),
+          AdvanceManualClock(100),
+          CheckNewAnswer(5),
+          AddData(inputData, 6),
+          AdvanceManualClock(100),
+          CheckNewAnswer(6),
+          AddData(inputData, 7),
+          AdvanceManualClock(800), // should trigger offset commit write to durable storage
+          CheckNewAnswer(7),
+          Execute { q =>
+            // wait for async log writes to complete
+            waitPendingOffsetWrites(q)
+            // can contain batches 0, 3, 7 or 3, 7

Review Comment:
   no guarantee at this point as the timing is hard to control
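
One way to make such a check tolerant of purge timing is to accept any of the listings that are valid at that point. The following is only a sketch: `PurgeOutcomeCheck` and `isAcceptable` are hypothetical names that do not appear in the PR, and the two acceptable listings are taken from the comment in the test above ("0, 3, 7 or 3, 7").

```scala
// Hypothetical helper, not part of the PR: accept any one of several valid
// log listings instead of asserting a single exact listing.
object PurgeOutcomeCheck {
  // Returns true when the batch ids found on disk, once sorted, match one of
  // the listings considered valid at this point in the test.
  def isAcceptable(found: Seq[Int], acceptable: Set[Seq[Int]]): Boolean =
    acceptable.contains(found.sorted)

  def main(args: Array[String]): Unit = {
    // The async purge may or may not have removed batch 0 yet.
    val acceptable = Set(Seq(0, 3, 7), Seq(3, 7))
    println(isAcceptable(Seq(7, 3), acceptable))
    println(isAcceptable(Seq(0, 3), acceptable))
  }
}
```

In the test itself this would presumably sit inside an `eventually` block, so that a slow purge is retried rather than failing on the first read.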





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049243767


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no-op for async progress tracking since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"
+          )
+        }
+      }
+      offsetLog.removeAsyncOffsetWrite(currentBatchId)
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  // need to look at the number of files on disk
+  override def purge(threshold: Long): Unit = {
+    while (offsetLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      offsetLog.writtenToDurableStorage.poll()
+    }
+    offsetLog.purge(offsetLog.writtenToDurableStorage.peek())
+
+    while (commitLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      commitLog.writtenToDurableStorage.poll()
+    }
+    commitLog.purge(commitLog.writtenToDurableStorage.peek())
+  }
+
+  override def cleanup(): Unit = {
+    super.cleanup()
+
+    ThreadUtils.shutdown(asyncWritesExecutorService)
+    logInfo(s"Async progress tracking executor pool for query ${prettyIdString} has been shutdown")
+  }
+
+  // used for testing
+  def areWritesPendingOrInProgress(): Boolean = {
+    asyncWritesExecutorService.getQueue.size() > 0 || asyncWritesExecutorService.getActiveCount > 0
+  }
+
+  private def validateAndGetTrigger(): TriggerExecutor = {
+    // validate that the pipeline is using a supported sink
+    if (!extraOptions
+      .get(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK
+      )
+      .getOrElse("false")
+      .toBoolean) {
+      try {
+        plan.sink.name() match {
+          case "noop-table" =>
+          case "console" =>
+          case "MemorySink" =>
+          case "KafkaTable" =>
+          case _ =>
+            throw new IllegalArgumentException(
+              s"Sink ${plan.sink.name()}" +
+                s" does not support async progress tracking"
+            )
+        }
+      } catch {
+        case e: IllegalStateException =>
+          // sink does not implement name() method
+          if (e.getMessage.equals("should not be called.")) {

Review Comment:
   The exception message will be confusing. It will be IllegalStateException("should not be called.")
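
A minimal sketch of the idea, under stated assumptions: `NamedSink` is a hypothetical stand-in for a sink whose default `name()` throws, not Spark's actual sink API, and `SinkSupportCheck.validate` is an illustrative helper, not the PR's method. The supported names are the ones matched in the diff above. The point is only that the generic "should not be called." error is turned into a message that names the offending sink.

```scala
// Hypothetical stand-in for a sink; by default name() is not implemented and
// throws the generic error quoted in the comment above.
trait NamedSink {
  def name(): String = throw new IllegalStateException("should not be called.")
}

object SinkSupportCheck {
  // Sink names matched by the validation code in the diff above.
  private val supported = Set("noop-table", "console", "MemorySink", "KafkaTable")

  // Returns Right(sinkName) for a supported sink, otherwise Left(message).
  def validate(sink: NamedSink): Either[String, String] = {
    val sinkName =
      try sink.name()
      catch {
        // The sink does not implement name(): fall back to its class name so
        // the resulting message is not the bare "should not be called.".
        case e: IllegalStateException if e.getMessage == "should not be called." =>
          sink.getClass.getName
      }
    if (supported.contains(sinkName)) Right(sinkName)
    else Left(s"Sink $sinkName does not support async progress tracking")
  }
}
```

Returning an `Either` keeps the sketch easy to exercise in isolation; the code under review throws `IllegalArgumentException` instead.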





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049243767


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // This is a no-op for async progress tracking, since we want to commit sources only
+    // after the offset WAL commit has been successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check whether an async write to the offset log was issued for the current batch;
+      // if so, we should do the same for the commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"
+          )
+        }
+      }
+      offsetLog.removeAsyncOffsetWrite(currentBatchId)
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  // need to look at the number of files on disk
+  override def purge(threshold: Long): Unit = {
+    while (offsetLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      offsetLog.writtenToDurableStorage.poll()
+    }
+    offsetLog.purge(offsetLog.writtenToDurableStorage.peek())
+
+    while (commitLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      commitLog.writtenToDurableStorage.poll()
+    }
+    commitLog.purge(commitLog.writtenToDurableStorage.peek())
+  }
+
+  override def cleanup(): Unit = {
+    super.cleanup()
+
+    ThreadUtils.shutdown(asyncWritesExecutorService)
+    logInfo(s"Async progress tracking executor pool for query ${prettyIdString} has been shutdown")
+  }
+
+  // used for testing
+  def areWritesPendingOrInProgress(): Boolean = {
+    asyncWritesExecutorService.getQueue.size() > 0 || asyncWritesExecutorService.getActiveCount > 0
+  }
+
+  private def validateAndGetTrigger(): TriggerExecutor = {
+    // validate that the pipeline is using a supported sink
+    if (!extraOptions
+      .get(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK
+      )
+      .getOrElse("false")
+      .toBoolean) {
+      try {
+        plan.sink.name() match {
+          case "noop-table" =>
+          case "console" =>
+          case "MemorySink" =>
+          case "KafkaTable" =>
+          case _ =>
+            throw new IllegalArgumentException(
+              s"Sink ${plan.sink.name()}" +
+                s" does not support async progress tracking"
+            )
+        }
+      } catch {
+        case e: IllegalStateException =>
+          // sink does not implement name() method
+          if (e.getMessage.equals("should not be called.")) {

Review Comment:
   The exception message will be confusing
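
   One possible way to address this, sketched under assumptions: instead of comparing `e.getMessage` against "should not be called.", the check could resolve the sink name once and build its own explicit message for each failure case. `SinkSupportCheck` and `validateSink` are hypothetical names used only for illustration and are not part of the PR; the sink names are copied from the match in the diff above.

   ```scala
   import scala.util.{Failure, Success, Try}

   // Hypothetical helper, not part of the PR: keeps the supported-sink check in one place
   // and returns an explicit, self-describing message for every rejection.
   object SinkSupportCheck {
     // Sink names copied from the match expression in the diff above.
     val supportedSinks: Set[String] = Set("noop-table", "console", "MemorySink", "KafkaTable")

     // `sinkName` is a by-name parameter so that a sink whose name() throws is handled here.
     def validateSink(sinkName: => String): Either[String, String] =
       Try(sinkName) match {
         case Success(name) if supportedSinks.contains(name) =>
           Right(name)
         case Success(name) =>
           Left(s"Sink $name does not support async progress tracking")
         case Failure(_: IllegalStateException) =>
           // Assumption: a sink that does not implement name() throws IllegalStateException.
           Left("Sink does not implement name(), so async progress tracking " +
             "support cannot be determined")
         case Failure(other) =>
           // Anything else is unexpected and is rethrown unchanged.
           throw other
       }
   }
   ```

   With this shape the caller can decide whether a `Left` becomes an `IllegalArgumentException`, and the text shown to the user never depends on the wording of an internal exception.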





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049092707


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batch 2, 3, 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was enabled in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // no new commits should be logged since async offset commits are enabled
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batch 2, 3, 4, 6, 7, 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 20 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // no additional commit log entries should be logged since async offset commit is on
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 20 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // the new data should add exactly one more offset log entry (batch 21)
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1;
+        // this should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 to 5 so that gaps can then be created in the offset and commit logs
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run batches 0 to 2 so that gaps can then be created in the offset and commit logs
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1
+      CheckNewAnswer(1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 3, 4, 5))
+  }
+
+  test("recovery non-contiguous log") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // delete batch 3 from commit log
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * Turn async progress tracking off and test recovery
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(1, 2, 3, 4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+    // batches 0, 3, 4 and 5 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 5))
+
+  }
+
+  test("switching async progress tracking with interval commits on and off") {
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    var clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    inputData.commits.length should be(0)
+
+    /**
+     * Turn async progress tracking off
+     */
+    testStream(ds)(
+      // need to set processing time to something so manual clock will work
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(3).json)
+        inputData.commits.clear()
+      },
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+      Execute { q =>
+        inputData.commits.length should equal(1)
+        inputData.commits(0).json should equal(LongOffset(4).json)
+        inputData.commits.clear()
+      },
+      StopStream
+    )
+    // batches 0 and 3 should be logged

Review Comment:
   will remove



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049119941


##########
connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala:
##########
@@ -195,6 +200,102 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
     true
   }
 
+  /**
+   * Test async progress tracking capability with Kafka source and sink
+   */
+  test("async progress tracking") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 5)
+
+    val dataSent = new ListBuffer[String]()
+    testUtils.sendMessages(inputTopic, (0 until 15).map { case x =>
+      val m = s"foo-$x"
+      dataSent += m
+      m
+    }.toArray, Some(0))
+
+    val outputTopic = newTopic()
+    testUtils.createTopic(outputTopic, partitions = 5)
+
+    withTempDir { dir =>
+      val reader = spark
+        .readStream
+        .format("kafka")
+        .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+        .option("kafka.metadata.max.age.ms", "1")
+        .option("maxOffsetsPerTrigger", 5)
+        .option("subscribe", inputTopic)
+        .option("startingOffsets", "earliest")
+        .load()
+
+      def startQuery(): StreamingQuery = {
+        reader.writeStream
+          .format("kafka")
+          .option("checkpointLocation", dir.getCanonicalPath)
+          .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+          .option("kafka.max.block.ms", "5000")
+          .option("topic", outputTopic)
+          .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+          .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+          .queryName("kafkaStream")
+          .start()
+      }
+
+      def readResults(): ListBuffer[String] = {

Review Comment:
   The output would be the same, but both the code and the actual execution would be much simpler with a batch query. See the code below for what this looks like as a batch query:
   
   ```
   // assumes testImplicits._ (or spark.implicits._) is in scope for the String encoder
   spark.read
     .format("kafka")
     .option("kafka.bootstrap.servers", testUtils.brokerAddress)
     .option("startingOffsets", "earliest")
     .option("subscribe", outputTopic)
     .load()
     .selectExpr("CAST(value AS STRING)")
     .as[String]
     .collect()
   ```
   
   The entire code in the method can be replaced with this query. I haven't tried it, but the code that actually runs shouldn't be much different.





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048920790


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncCommitLog.scala:
##########
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.OutputStream
+import java.util.concurrent.{CompletableFuture, ConcurrentLinkedDeque, ThreadPoolExecutor}
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Implementation of CommitLog to perform asynchronous writes to storage
+ */
+class AsyncCommitLog(sparkSession: SparkSession, path: String, executorService: ThreadPoolExecutor)
+    extends CommitLog(sparkSession, path) {
+
+  // A queue of batches written to storage.  Used to keep track of when to purge old batches
+  val writtenToDurableStorage =
+    new ConcurrentLinkedDeque[Long](listBatchesOnDisk.toList.asJavaCollection)
+
+  /**
+   * Writes a new batch to the commit log asynchronously
+   * @param batchId id of batch to write
+   * @param metadata metadata of batch to write
+   * @return a CompletableFuture that contains the batch id.  The future is completed when
+   *         the async write of the batch is completed.  Future may also be completed exceptionally
+   *         to indicate some write error.
+   */
+  def addAsync(batchId: Long, metadata: CommitMetadata): CompletableFuture[Long] = {
+    require(metadata != null, "'null' metadata cannot be written to a metadata log")
+    val future: CompletableFuture[Long] = addNewBatchByStreamAsync(batchId) { output =>
+      serialize(metadata, output)
+    }.thenApply((ret: Boolean) => {
+      if (ret) {
+        batchId
+      } else {
+        throw new IllegalStateException(
+          s"Concurrent update to the log. Multiple streaming jobs detected for $batchId"

Review Comment:
   This is the message in the existing implementation. It does not differentiate between logs, but you should be able to tell which log it is from the stack trace.
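
   To illustrate the stack trace point, here is a minimal sketch in plain Scala. `CommitLogSketch`, `OffsetLogSketch`, and `failingClass` are hypothetical names made up for this example and are not part of the PR; the `written` flag stands in for the result of the real file write. It assumes, as in the hunk above, that each log throws from its own `thenApply` callback, so the top frame of the cause should name the class that detected the conflict even though both messages are identical.

```scala
import java.util.concurrent.{CompletableFuture, CompletionException}

// Hypothetical stand-in for the commit log; not the class from the PR.
class CommitLogSketch {
  def addAsync(batchId: Long, written: Boolean): CompletableFuture[Long] = {
    val result: CompletableFuture[Long] =
      CompletableFuture.completedFuture(written).thenApply((ret: Boolean) => {
        if (ret) {
          batchId
        } else {
          // Same wording as the message quoted in the diff.
          throw new IllegalStateException(
            s"Concurrent update to the log. Multiple streaming jobs detected for $batchId")
        }
      })
    result
  }
}

// Hypothetical stand-in for the offset log; it throws the identical message.
class OffsetLogSketch {
  def addAsync(batchId: Long, written: Boolean): CompletableFuture[Long] = {
    val result: CompletableFuture[Long] =
      CompletableFuture.completedFuture(written).thenApply((ret: Boolean) => {
        if (ret) {
          batchId
        } else {
          throw new IllegalStateException(
            s"Concurrent update to the log. Multiple streaming jobs detected for $batchId")
        }
      })
    result
  }
}

object AsyncLogErrorSketch {
  // Returns the class name of the top stack frame of the failure cause,
  // or None when the future completed normally.
  def failingClass(future: CompletableFuture[Long]): Option[String] = {
    try {
      future.join()
      None
    } catch {
      case e: CompletionException =>
        e.getCause.getStackTrace.headOption.map(_.getClassName)
    }
  }

  def main(args: Array[String]): Unit = {
    println(failingClass(new CommitLogSketch().addAsync(1L, written = false)))
    println(failingClass(new OffsetLogSketch().addAsync(1L, written = false)))
  }
}
```

   Calling `join()` on a failed future throws a `CompletionException` whose cause is the original `IllegalStateException`, so in this sketch it is the frames of the cause, not the message, that distinguish the two logs.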



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048934596


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,

Review Comment:
   ok





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049027681


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // data should contain every value exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for the 10 new batches should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error by writing a file for the next offset i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write are bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write are bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)

Review Comment:
   why?





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049026257


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote to the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 0 through 6 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // data should contain every value exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // the commit log should contain the same batches as the offset log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by creating a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by creating a file for the next commit, i.e. 1.
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close()
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the offsets directory read-only
+        import java.io._
+        val offsetsDir = new File(checkpointLocation + "/offsets")
+        offsetsDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by making the commits directory read-only
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {

Review Comment:
   will fix





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049075469


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure there are no pending async offset writes left
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure there are no pending async offset writes left
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 of the commit log has been written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 of the commit log has been written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // every value should appear exactly once, in arrival order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 of the commit log has been written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for batches 11 through 20 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 of the commit log has been written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should also be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // process batches 0 through 5; gaps in the offset and commit logs are created afterwards
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +

Review Comment:
   sure





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049073555


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for all batches, 0 through 4, should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should be logged as well
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for all batches, 0 through 8, should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 values should have been received exactly once, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 10 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // the commit log should also contain the 10 new batches
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // Simulate a write error by creating a file for the next offset, i.e. 1.
+        // This should create a conflict.
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors occurring during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are being executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // run six batches; gaps are then created by deleting offset and commit log files
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),

Review Comment:
   Will add a comment.
   
   > And would the behavior be same between sync and async progress tracking?
   
   Yes
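
For readers following the two recovery tests above, here is a rough model of what they appear to assert. It is only a sketch under assumptions, not the actual `MicroBatchExecution` logic: `RecoverySketch`, `Plan`, `Rerun`, `Continue`, `StartFresh`, and `plan` are hypothetical names introduced for illustration. The model takes the batch ids present in the offset log and the commit log and says which batch a restarted query would run next, and after which committed batch its data would start.

```scala
// Simplified, hypothetical model of restart behaviour with gaps in the logs.
// Not Spark's implementation; names are made up for this illustration.
object RecoverySketch {
  sealed trait Plan
  // No offset log entries at all: nothing to recover.
  case object StartFresh extends Plan
  // The latest offset log entry has no matching commit: run that batch again,
  // reading data that follows the latest committed batch (None = from the start).
  final case class Rerun(batchId: Long, afterCommittedBatch: Option[Long]) extends Plan
  // The latest offset log entry is committed: carry on with the next batch id.
  final case class Continue(nextBatchId: Long) extends Plan

  def plan(offsetIds: Seq[Long], commitIds: Seq[Long]): Plan =
    offsetIds.sorted.lastOption match {
      case None => StartFresh
      case Some(latest) =>
        // Ignore any commit entries newer than the latest offset entry.
        val lastCommitted = commitIds.filter(_ <= latest).sorted.lastOption
        if (lastCommitted.contains(latest)) Continue(latest + 1)
        else Rerun(latest, lastCommitted)
    }
}
```

Under this model, offsets `{0, 2, 5}` with commits `{0, 2}` give `Rerun(5, Some(2))`, which lines up with `CheckNewAnswer(3, 4, 5)` in the gaps test, and offsets `{2}` with no commits give `Rerun(2, None)`, which lines up with the replay of `0, 1, 2` in the earlier test.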



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048788663


##########
connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala:
##########
@@ -195,6 +200,102 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase {
     true
   }
 
+  /**
+   * Test async progress tracking capability with Kafka source and sink
+   */
+  test("async progress tracking") {
+    val inputTopic = newTopic()
+    testUtils.createTopic(inputTopic, partitions = 5)
+
+    val dataSent = new ListBuffer[String]()
+    testUtils.sendMessages(inputTopic, (0 until 15).map { case x =>
+      val m = s"foo-$x"
+      dataSent += m
+      m
+    }.toArray, Some(0))
+
+    val outputTopic = newTopic()
+    testUtils.createTopic(outputTopic, partitions = 5)
+
+    withTempDir { dir =>
+      val reader = spark
+        .readStream
+        .format("kafka")
+        .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+        .option("kafka.metadata.max.age.ms", "1")
+        .option("maxOffsetsPerTrigger", 5)
+        .option("subscribe", inputTopic)
+        .option("startingOffsets", "earliest")
+        .load()
+
+      def startQuery(): StreamingQuery = {
+        reader.writeStream
+          .format("kafka")
+          .option("checkpointLocation", dir.getCanonicalPath)
+          .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+          .option("kafka.max.block.ms", "5000")
+          .option("topic", outputTopic)
+          .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+          .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+          .queryName("kafkaStream")
+          .start()
+      }
+
+      def readResults(): ListBuffer[String] = {

Review Comment:
   OK, if it works all the same, let's just keep it as it is currently.





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048939040


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful

Review Comment:
   ok





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048935970


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {

Review Comment:
   ok





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1048957823


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no op for async progress tracking since we only want to commit sources only
+    // after the offset WAL commit has be successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check if current batch there is a async write for the offset log is issued for this batch
+      // if so, we should do the same for commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"

Review Comment:
   This is the same logic that we currently have in MicroBatchExecution.
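
   To illustrate the check being discussed, here is a minimal sketch of the "add returns false when the batch already exists" pattern. It is not Spark's actual CommitLog or AsyncCommitLog: `InMemoryBatchLog`, `ConcurrentUpdateCheck` and `commitOrFail` are hypothetical names used only for illustration, and the log is assumed to be a plain in-memory map.

```scala
import java.util.concurrent.ConcurrentHashMap

// Hypothetical stand-in for a metadata log; NOT Spark's CommitLog.
final class InMemoryBatchLog {
  private val entries = new ConcurrentHashMap[Long, String]()

  // Returns true only if no entry existed yet for this batch id.
  def add(batchId: Long, metadata: String): Boolean =
    entries.putIfAbsent(batchId, metadata) == null
}

object ConcurrentUpdateCheck {
  // Same shape as the check quoted above: a rejected add is treated as a sign
  // that another writer has already recorded this batch.
  def commitOrFail(log: InMemoryBatchLog, batchId: Long, metadata: String): Unit = {
    if (!log.add(batchId, metadata)) {
      throw new IllegalStateException(
        s"Concurrent update to the log. Multiple streaming jobs detected for $batchId")
    }
  }
}
```

   In this sketch the first write for a batch id succeeds and any later write for the same id fails, which is the behaviour the existing assert in MicroBatchExecution relies on.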



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecution.scala:
##########
@@ -0,0 +1,282 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.streaming.Trigger
+import org.apache.spark.util.{Clock, ThreadUtils}
+
+object AsyncProgressTrackingMicroBatchExecution {
+  val ASYNC_PROGRESS_TRACKING_ENABLED = "asyncProgressTrackingEnabled"
+  val ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS =
+    "asyncProgressTrackingCheckpointIntervalMs"
+
+  // for testing purposes
+  val ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK =
+    "_asyncProgressTrackingOverrideSinkSupportCheck"
+
+  private def getAsyncProgressTrackingCheckpointingIntervalMs(
+      extraOptions: Map[String, String]): Long = {
+    extraOptions
+      .getOrElse(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS,
+        "1000"
+      )
+      .toLong
+  }
+}
+
+/**
+ * Class to execute micro-batches when async progress tracking is enabled
+ */
+class AsyncProgressTrackingMicroBatchExecution(
+    sparkSession: SparkSession,
+    trigger: Trigger,
+    triggerClock: Clock,
+    extraOptions: Map[String, String],
+    plan: WriteToStream)
+    extends MicroBatchExecution(sparkSession, trigger, triggerClock, extraOptions, plan) {
+
+  protected val asyncProgressTrackingCheckpointingIntervalMs: Long
+  = AsyncProgressTrackingMicroBatchExecution
+    .getAsyncProgressTrackingCheckpointingIntervalMs(extraOptions)
+
+  // Offsets that are ready to be committed by the source.
+  // This is needed so that we can call source commit in the same thread as micro-batch execution
+  // to be thread safe
+  private val sourceCommitQueue = new ConcurrentLinkedQueue[OffsetSeq]()
+
+  // to cache the batch id of the last batch written to storage
+  private val lastBatchPersistedToDurableStorage = new AtomicLong(-1)
+
+  override val triggerExecutor: TriggerExecutor = validateAndGetTrigger()
+
+  private var isFirstBatch: Boolean = true
+
+  // thread pool is only one thread because we want offset
+  // writes to execute in order in a serialized fashion
+  protected val asyncWritesExecutorService
+  = ThreadUtils.newDaemonSingleThreadExecutorWithRejectedExecutionHandler(
+    "async-log-write",
+    2, // one for offset commit and one for completion commit
+    new RejectedExecutionHandler() {
+      override def rejectedExecution(r: Runnable, executor: ThreadPoolExecutor): Unit = {
+        try {
+          if (!executor.isShutdown) {
+            val start = System.currentTimeMillis()
+            executor.getQueue.put(r)
+            logDebug(
+              s"Async write paused execution for " +
+                s"${System.currentTimeMillis() - start} due to task queue being full."
+            )
+          }
+        } catch {
+          case e: InterruptedException =>
+            Thread.currentThread.interrupt()
+            throw new RejectedExecutionException("Producer interrupted", e)
+          case e: Throwable =>
+            logError("Encountered error in async write executor service", e)
+            errorNotifier.markError(e)
+        }
+      }
+    })
+
+  override val offsetLog = new AsyncOffsetSeqLog(
+    sparkSession,
+    checkpointFile("offsets"),
+    asyncWritesExecutorService,
+    asyncProgressTrackingCheckpointingIntervalMs,
+    clock = triggerClock
+  )
+
+  override val commitLog =
+    new AsyncCommitLog(sparkSession, checkpointFile("commits"), asyncWritesExecutorService)
+
+  override def markMicroBatchExecutionStart(): Unit = {
+    // check if pipeline is stateful
+    checkNotStatefulPipeline
+  }
+
+  override def cleanUpLastExecutedMicroBatch(): Unit = {
+    // this is a no op for async progress tracking since we only want to commit sources only
+    // after the offset WAL commit has be successfully written
+  }
+
+  /**
+   * Should not call super method as we need to do something completely different
+   * in this method for async progress tracking
+   */
+  override def markMicroBatchStart(): Unit = {
+    // Because we are using a thread pool with only one thread, async writes to the offset log
+    // are still written in a serial / in order fashion
+    offsetLog
+      .addAsync(currentBatchId, availableOffsets.toOffsetSeq(sources, offsetSeqMetadata))
+      .thenAccept(tuple => {
+        val (batchId, persistedToDurableStorage) = tuple
+        if (persistedToDurableStorage) {
+
+          // batch id cache not initialized
+          if (lastBatchPersistedToDurableStorage.get == -1) {
+            lastBatchPersistedToDurableStorage.set(
+              offsetLog.getPrevBatchFromStorage(batchId).getOrElse(-1))
+          }
+
+          if (batchId != 0 && lastBatchPersistedToDurableStorage.get != -1) {
+            // sanity check to make sure batch ids are monotonically increasing
+            assert(lastBatchPersistedToDurableStorage.get < batchId)
+            val prevBatchOff = offsetLog.get(lastBatchPersistedToDurableStorage.get())
+            if (prevBatchOff.isDefined) {
+              // Offset is ready to be committed by the source. Add to queue
+              sourceCommitQueue.add(prevBatchOff.get)
+            } else {
+              throw new IllegalStateException(
+                s"batch ${lastBatchPersistedToDurableStorage.get()} doesn't exist"
+              )
+            }
+          }
+          lastBatchPersistedToDurableStorage.set(batchId)
+        }
+      })
+      .exceptionally((th: Throwable) => {
+        logError("Encountered error while performing async offset write", th)
+        errorNotifier.markError(th)
+        return
+      })
+
+    // check if there are offsets that are ready to be committed by the source
+    var offset = sourceCommitQueue.poll()
+    while (offset != null) {
+      commitSources(offset)
+      offset = sourceCommitQueue.poll()
+    }
+  }
+
+  override def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      // check if current batch there is a async write for the offset log is issued for this batch
+      // if so, we should do the same for commit log
+      if (!offsetLog.getAsyncOffsetWrite(currentBatchId).isEmpty) {
+        commitLog
+          .addAsync(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
+          .exceptionally((th: Throwable) => {
+            logError("Got exception during async write", th)
+            errorNotifier.markError(th)
+            return
+          })
+      } else {
+        if (!commitLog.addInMemory(
+          currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))) {
+          throw new IllegalStateException(
+            s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId"
+          )
+        }
+      }
+      offsetLog.removeAsyncOffsetWrite(currentBatchId)
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  // need to look at the number of files on disk
+  override def purge(threshold: Long): Unit = {
+    while (offsetLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      offsetLog.writtenToDurableStorage.poll()
+    }
+    offsetLog.purge(offsetLog.writtenToDurableStorage.peek())
+
+    while (commitLog.writtenToDurableStorage.size() > minLogEntriesToMaintain) {
+      commitLog.writtenToDurableStorage.poll()
+    }
+    commitLog.purge(commitLog.writtenToDurableStorage.peek())
+  }
+
+  override def cleanup(): Unit = {
+    super.cleanup()
+
+    ThreadUtils.shutdown(asyncWritesExecutorService)
+    logInfo(s"Async progress tracking executor pool for query ${prettyIdString} has been shutdown")
+  }
+
+  // used for testing
+  def areWritesPendingOrInProgress(): Boolean = {
+    asyncWritesExecutorService.getQueue.size() > 0 || asyncWritesExecutorService.getActiveCount > 0
+  }
+
+  private def validateAndGetTrigger(): TriggerExecutor = {
+    // validate that the pipeline is using a supported sink
+    if (!extraOptions
+      .get(
+        AsyncProgressTrackingMicroBatchExecution.ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK

Review Comment:
   sure





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049012172


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -727,18 +719,56 @@ class MicroBatchExecution(
 
     withProgressLocked {
       sinkCommitProgress = batchSinkProgress
-      watermarkTracker.updateWatermark(lastExecution.executedPlan)
-      reportTimeTaken("commitOffsets") {
-        assert(commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark)),
-          "Concurrent update to the commit log. Multiple streaming jobs detected for " +
-            s"$currentBatchId")
-      }
-      committedOffsets ++= availableOffsets
+      markMicroBatchEnd()
     }
     logDebug(s"Completed batch ${currentBatchId}")
   }
 
-  /** Execute a function while locking the stream from making an progress */
+
+  /**
+   * Called at the start of the micro batch with given offsets. It takes care of offset
+   * checkpointing to offset log and any microbatch startup tasks.
+   */
+  protected def markMicroBatchStart(): Unit = {
+    assert(offsetLog.add(currentBatchId,
+      availableOffsets.toOffsetSeq(sources, offsetSeqMetadata)),
+      s"Concurrent update to the log. Multiple streaming jobs detected for $currentBatchId")
+    logInfo(s"Committed offsets for batch $currentBatchId. " +
+      s"Metadata ${offsetSeqMetadata.toString}")
+  }
+
+  /**
+   * Method called once after the planning is done and before the start of the microbatch execution.
+   * It can be used to perform any pre-execution tasks.
+   */
+  protected def markMicroBatchExecutionStart(): Unit = {}
+
+  /**
+   * Called after the microbatch has completed execution. It takes care of committing the offset
+   * to commit log and other bookkeeping.
+   */
+  protected def markMicroBatchEnd(): Unit = {
+    watermarkTracker.updateWatermark(lastExecution.executedPlan)
+    reportTimeTaken("commitOffsets") {
+      assert(commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark)),
+        "Concurrent update to the commit log. Multiple streaming jobs detected for " +
+          s"$currentBatchId")
+    }
+    committedOffsets ++= availableOffsets
+  }
+
+  protected def cleanUpLastExecutedMicroBatch(): Unit = {
+    if (currentBatchId != 0) {
+      val prevBatchOff = offsetLog.get(currentBatchId - 1)
+      if (prevBatchOff.isDefined) {
+        commitSources(prevBatchOff.get)
+      } else {
+        throw new IllegalStateException(s"batch ${currentBatchId - 1} doesn't exist")
+      }
+    }
+  }
+
+    /** Execute a function while locking the stream from making an progress */

Review Comment:
   will fix





[GitHub] [spark] jerrypeng commented on a diff in pull request #38517: [SPARK-39591][SS] Async Progress Tracking

Posted by GitBox <gi...@apache.org>.
jerrypeng commented on code in PR #38517:
URL: https://github.com/apache/spark/pull/38517#discussion_r1049076380


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the previous run wrote the commit log
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for batches 5 and 6 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // all 40 records should have been received exactly once and in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // commit log entries for batches 11 through 20 should also be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // batch 21 should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by pre-creating a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async offset log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during an async commit log write get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set the processing time to something so the manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log contains batches 0, 2, 5 and the commit log
+    // contains batches 0, 2
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log contains batches 0, 2 and the commit log
+    // contains only batch 0
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1

Review Comment:
   The offset log contains batches 0 and 2, while the commit log contains only batch 0. Thus, on a restart, offsets 1-2 will be processed.
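
The reasoning in this comment can be sketched as a small, self-contained model. This is only an illustration and not Spark's recovery code: `RestartSketch` and `batchesToReprocess` are hypothetical names, and the model assumes that everything after the latest committed batch, up to the latest batch in the offset log, has to be processed again. How the real implementation groups that range into batches is not modelled. In the tests quoted here each batch carries exactly one value equal to its batch id, so batch ids and data values coincide.

```scala
// Simplified model of the restart reasoning; NOT Spark's recovery code.
// RestartSketch and batchesToReprocess are hypothetical names.
object RestartSketch {

  /**
   * Given the batch ids found in the offset log and in the commit log,
   * return the batch ids whose data would be processed again on restart:
   * everything after the latest committed batch, up to and including the
   * latest batch recorded in the offset log.
   */
  def batchesToReprocess(offsetLogIds: Seq[Int], commitLogIds: Seq[Int]): Seq[Int] = {
    // -1 stands for "nothing recorded yet"
    val latestOffset = if (offsetLogIds.isEmpty) -1 else offsetLogIds.max
    val latestCommit = if (commitLogIds.isEmpty) -1 else commitLogIds.max
    // the range is empty when the commit log has caught up with the offset log
    ((latestCommit + 1) to latestOffset).toList
  }

  def main(args: Array[String]): Unit = {
    // offsets = {0, 2}, commits = {0}
    println(batchesToReprocess(Seq(0, 2), Seq(0))) // prints List(1, 2)
  }
}
```

Under the same assumptions, the other two scenarios in this suite behave as the tests expect: offsets `{2}` with an empty commit log yields `0, 1, 2` (replay from the beginning), and offsets `{0, 2, 5}` with commits `{0, 2}` yields `3, 4, 5`.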



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala:
##########
@@ -0,0 +1,1865 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.io.{File, OutputStream}
+import java.util.concurrent.{CountDownLatch, Semaphore, ThreadPoolExecutor, TimeUnit}
+
+import scala.collection.mutable.ListBuffer
+
+import org.apache.hadoop.fs.Path
+import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should.Matchers
+import org.scalatest.time.{Seconds, Span}
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.streaming.WriteToStream
+import org.apache.spark.sql.connector.read.streaming
+import org.apache.spark.sql.execution.streaming.AsyncProgressTrackingMicroBatchExecution.{ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, ASYNC_PROGRESS_TRACKING_ENABLED, ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK}
+import org.apache.spark.sql.functions.{column, window}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryException, StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.util.{Clock, Utils}
+
+class AsyncProgressTrackingMicroBatchExecutionSuite
+    extends StreamTest
+    with BeforeAndAfter
+    with Matchers {
+
+  import testImplicits._
+
+  after {
+    sqlContext.streams.active.foreach(_.stop())
+  }
+
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  def waitPendingOffsetWrites(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .areWritesPendingOrInProgress() should be(false)
+    }
+  }
+
+  def waitPendingPurges(streamExecution: StreamExecution): Unit = {
+    assert(streamExecution.isInstanceOf[AsyncProgressTrackingMicroBatchExecution])
+    eventually(timeout(Span(5, Seconds))) {
+      streamExecution
+        .asInstanceOf[AsyncProgressTrackingMicroBatchExecution]
+        .arePendingAsyncPurge should be(false)
+    }
+  }
+
+  // test the basic functionality i.e. happy path
+  test("async WAL commits happy path") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val tableName = "test"
+
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .format("memory")
+        .queryName(tableName)
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Row]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += Row(v)
+      }
+      query.processAllAvailable()
+    }
+
+    checkAnswer(
+      spark.table(tableName),
+      expected.toSeq
+    )
+  }
+
+  test("async WAL commits recovery") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    // to synchronize producing and consuming messages so that
+    // we can generate and read the desired number of batches
+    var countDownLatch = new CountDownLatch(10)
+    val sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    var query = startQuery()
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    assert(index == 10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    countDownLatch = new CountDownLatch(10)
+
+    /**
+     * Start the query again
+     */
+    query = startQuery()
+
+    for (i <- 10 until 20) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+    } finally {
+      query.stop()
+    }
+
+    // convert data to set to deduplicate results
+    data.toSet should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19).toSet
+    )
+  }
+
+  test("async WAL commits turn on and off") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      AddData(inputData, 0),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckAnswer(0),
+      AddData(inputData, 1),
+      CheckAnswer(0, 1),
+      AddData(inputData, 2),
+      CheckAnswer(0, 1, 2),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2))
+
+    /**
+     * Starting stream second time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 3),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+    // commits for batches 0 through 4 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4))
+
+    /**
+     * Starting stream third time with async progress tracking turned back on
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      // add new data
+      AddData(inputData, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      // no data needs to be replayed because the commit log was written in the previous run
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      CheckNewAnswer(6),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // make sure we have removed all pending commits
+        q.offsetLog.asInstanceOf[AsyncOffsetSeqLog].pendingAsyncOffsetWrite() should be(0)
+      },
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+    // commits for all batches, including the new batches 5 and 6, should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6))
+
+    /**
+     * Starting stream fourth time with async progress tracking turned off
+     */
+    testStream(ds)(
+      // add new data
+      AddData(inputData, 7),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(7),
+      AddData(inputData, 8),
+      CheckNewAnswer(8),
+      StopStream
+    )
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+    // commits for batches 0 through 8 should be logged
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8))
+  }
+
+  test("Fail with once trigger") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.Once())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with Once trigger")
+  }
+
+  test("Fail with available now trigger") {
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    val e = intercept[IllegalArgumentException] {
+      ds.writeStream
+        .format("noop")
+        .trigger(Trigger.AvailableNow())
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .start()
+    }
+    e.getMessage should equal("Async progress tracking cannot be used with AvailableNow trigger")
+  }
+
+  test("switching between async wal commit enabled and trigger once") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncProgressTracking: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncProgressTracking)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async progress tracking turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    // offsets should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger once
+     */
+
+    // trigger once should process batch 10
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit again
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.Once())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 21 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // data should contain every value from 0 to 39, in order
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39)
+    )
+
+    // batch 21 should be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+
+  }
+
+  test("switching between async wal commit enabled and available now") {
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDF()
+
+    var index = 0
+    var countDownLatch = new CountDownLatch(10)
+    var sem = new Semaphore(1)
+    val data = new ListBuffer[Int]()
+    def startQuery(
+        asyncOffsetCommitsEnabled: Boolean,
+        trigger: Trigger = Trigger.ProcessingTime(0)): StreamingQuery = {
+      ds.writeStream
+        .trigger(trigger)
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+          countDownLatch.countDown()
+          index += 1
+          sem.release()
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, asyncOffsetCommitsEnabled)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 0)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+
+    /*
+     start the query with async offset commits turned on
+     */
+    var query = startQuery(true)
+
+    for (i <- 0 until 10) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(9)
+      }
+    } finally {
+      query.stop()
+    }
+
+    index should equal(10)
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    countDownLatch = new CountDownLatch(1)
+    for (i <- 10 until 20) {
+      inputData.addData({ i })
+    }
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure batch 10 in the commit log is written to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(10)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19))
+
+    // since using trigger available now, the new data, i.e. batch 10, should also be processed
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+
+    /*
+     Turn on async offset commit again
+     */
+
+    countDownLatch = new CountDownLatch(10)
+    sem = new Semaphore(1)
+    query = startQuery(true)
+    for (i <- 20 until 30) {
+      sem.acquire()
+      inputData.addData({ i })
+    }
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(20)
+      }
+    } finally {
+      query.stop()
+    }
+
+    data should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+        24, 25, 26, 27, 28, 29)
+    )
+
+    // 10 more batches should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+    // the commit log should also have entries for the 10 new batches
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
+    )
+
+    /*
+     Turn off async offset commit and use trigger available now
+     */
+
+    for (i <- 30 until 40) {
+      inputData.addData({ i })
+    }
+
+    countDownLatch = new CountDownLatch(1)
+    query = startQuery(false, trigger = Trigger.AvailableNow())
+
+    try {
+      countDownLatch.await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS)
+      // make sure the batch 21 in the commit log writes to durable storage
+      eventually(timeout(Span(5, Seconds))) {
+        val files = getListOfFiles(checkpointLocation + "/commits")
+          .filter(file => !file.isHidden)
+          .map(file => file.getName.toInt)
+          .sorted
+
+        files.last should be(21)
+      }
+    } finally {
+      query.stop()
+    }
+
+    // since using trigger available now, the new data, i.e. batch 21, should be added to the offset log
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+    // batch 21 should also be added to the commit log
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(
+      Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)
+    )
+  }
+
+  class FailAsyncProgressTrackingMicroBatchExecution(
+      sparkSession: SparkSession,
+      trigger: Trigger,
+      triggerClock: Clock,
+      extraOptions: Map[String, String],
+      plan: WriteToStream)
+      extends AsyncProgressTrackingMicroBatchExecution(
+        sparkSession,
+        trigger,
+        triggerClock,
+        extraOptions,
+        plan
+      ) {
+
+    override val offsetLog = new FailAsyncOffsetSeqLog(
+      sparkSession,
+      checkpointFile("offsets"),
+      asyncWritesExecutorService,
+      asyncProgressTrackingCheckpointingIntervalMs
+    )
+  }
+
+  class FailAsyncOffsetSeqLog(
+      sparkSession: SparkSession,
+      path: String,
+      executorService: ThreadPoolExecutor,
+      offsetCommitIntervalMs: Long)
+      extends AsyncOffsetSeqLog(sparkSession, path, executorService, offsetCommitIntervalMs) {
+
+    override def write(
+        batchMetadataFile: Path,
+        fn: OutputStream => Unit): Unit = {
+
+      throw new Exception("test fail")
+    }
+  }
+
+  // Tests that errors that occur during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // simulate a write error by writing a file for the next offset, i.e. 1.
+        // This should create a conflict
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/offsets/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 1") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val pw = new PrintWriter(new File(checkpointLocation + "/commits/1"))
+        pw.write("Hello, world")
+        pw.close
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include(
+              "Concurrent update to the log. Multiple streaming jobs detected for 1")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async offset log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async offset log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val offsetDir = new File(checkpointLocation + "/offsets")
+        offsetDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  // Tests that errors that occur during async commit log writes get bubbled up
+  // to the main stream execution thread
+  test("bubble up async commit log write errors 2") {
+    val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "0"
+      )
+    )(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckAnswer(0),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+        // to simulate write error
+        import java.io._
+        val commitDir = new File(checkpointLocation + "/commits")
+        commitDir.setReadOnly()
+
+      },
+      AddData(inputData, 1),
+      Execute {
+        q =>
+          eventually(timeout(Span(5, Seconds))) {
+            val e = intercept[StreamingQueryException] {
+              q.processAllAvailable()
+            }
+            e.getCause.getCause.getMessage should include("Permission denied")
+          }
+      }
+    )
+  }
+
+  class MemoryStreamCapture[A: Encoder](
+      id: Int,
+      sqlContext: SQLContext,
+      numPartitions: Option[Int] = None)
+      extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
+
+    val commits = new ListBuffer[streaming.Offset]()
+    val commitThreads = new ListBuffer[Thread]()
+
+    override def commit(end: streaming.Offset): Unit = {
+      super.commit(end)
+      commits += end
+      commitThreads += Thread.currentThread()
+    }
+  }
+
+  test("commit intervals happy path") {
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+
+    val ds = inputData.toDF()
+
+    val data = new ListBuffer[Int]()
+    def startQuery(): StreamingQuery = {
+      ds.writeStream
+        .foreachBatch((ds: Dataset[Row], batchId: Long) => {
+          ds.collect.foreach((row: Row) => {
+            data += row.getInt(0)
+          }: Unit)
+        })
+        .option(ASYNC_PROGRESS_TRACKING_ENABLED, true)
+        .option(ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS, 1000)
+        .option(ASYNC_PROGRESS_TRACKING_OVERRIDE_SINK_SUPPORT_CHECK, true)
+        .option("checkpointLocation", checkpointLocation)
+        .start()
+    }
+    val query = startQuery()
+    val expected = new ListBuffer[Int]()
+    for (j <- 0 until 100) {
+      for (i <- 0 until 10) {
+        val v = i + (j * 10)
+        inputData.addData({ v })
+        expected += v
+      }
+      query.processAllAvailable()
+    }
+
+    eventually(timeout(Span(5, Seconds))) {
+      val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+        .filter(file => !file.isHidden)
+        .map(file => file.getName.toInt)
+        .sorted
+
+      offsetLogFiles should equal (commitLogFiles)
+    }
+
+    query.stop()
+
+    data should equal(expected)
+
+    val commitLogFiles = getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    val offsetLogFiles = getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted
+
+    logInfo(s"offsetLogFiles: ${offsetLogFiles}")
+    logInfo(s"commitLogFiles: ${commitLogFiles}")
+
+    val offsetLog = new AsyncOffsetSeqLog(ds.sparkSession, checkpointLocation + "/offsets", null, 0)
+    // commits received at source should match up to the ones found in the offset log
+    for (i <- 0 until inputData.commits.length) {
+      val offsetOnDisk: OffsetSeq = offsetLog.get(offsetLogFiles(i)).get
+
+      val sourceCommittedOffset: streaming.Offset = inputData.commits(i)
+
+      offsetOnDisk.offsets(0).get.json() should equal(sourceCommittedOffset.json())
+    }
+
+    // make sure that the source commits are being executed by the main stream execution thread
+    inputData.commitThreads.foreach(thread => {
+      thread.getName should include("stream execution thread for")
+      thread.getName should include(query.id.toString)
+      thread.getName should include(query.runId.toString)
+    })
+    commitLogFiles should equal(offsetLogFiles)
+  }
+
+  test("interval commits and recovery") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    val clock = new StreamManualClock
+
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // need to set processing time to something so manual clock will work
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 0),
+      AdvanceManualClock(100),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      AdvanceManualClock(100),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      AdvanceManualClock(100),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(3),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    // batches 0 and 3 should be logged
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3))
+
+    /**
+     * restart stream
+     */
+    testStream(
+      ds,
+      extraOptions = Map(
+        ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
+        ASYNC_PROGRESS_TRACKING_CHECKPOINTING_INTERVAL_MS -> "1000"
+      )
+    )(
+      // add new data
+      StartStream(
+        Trigger.ProcessingTime("1 millisecond"),
+        checkpointLocation = checkpointLocation,
+        triggerClock = clock
+      ),
+      AddData(inputData, 4), // should persist to durable storage since first batch after restart
+      AdvanceManualClock(100),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      AdvanceManualClock(100),
+      CheckNewAnswer(5),
+      AddData(inputData, 6),
+      AdvanceManualClock(100),
+      CheckNewAnswer(6),
+      AddData(inputData, 7),
+      AdvanceManualClock(800), // should trigger offset commit write to durable storage
+      CheckNewAnswer(7),
+      Execute { q =>
+        waitPendingOffsetWrites(q)
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 3, 4, 7))
+  }
+
+  test(
+    "recovery when first offset is not zero and" +
+    " no commit log entries"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // create a scenario in which the offset log only contains batch 2 and commit log is empty
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      Execute { q =>
+        q.offsetLog.purge(2)
+        getListOfFiles(checkpointLocation + "/commits").foreach(file => file.delete())
+      },
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array())
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from the beginning
+      CheckNewAnswer(0, 1, 2),
+      AddData(inputData2, 3),
+      CheckNewAnswer(3),
+      AddData(inputData2, 4),
+      CheckNewAnswer(4),
+      AddData(inputData2, 5),
+      CheckNewAnswer(5),
+      StopStream
+    )
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(2, 3, 4, 5))
+  }
+
+  test("test multiple gaps in offset and commit logs") {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // process batches 0 to 5, then delete files so that the offset log only contains
+    // batches 0, 2, 5 and the commit log only contains batches 0, 2
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      AddData(inputData, 3),
+      CheckNewAnswer(3),
+      AddData(inputData, 4),
+      CheckNewAnswer(4),
+      AddData(inputData, 5),
+      CheckNewAnswer(5),
+
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/offsets/3").delete()
+    new File(checkpointLocation + "/offsets/.3.crc").delete()
+    new File(checkpointLocation + "/offsets/4").delete()
+    new File(checkpointLocation + "/offsets/.4.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+    new File(checkpointLocation + "/commits/.3.crc").delete()
+    new File(checkpointLocation + "/commits/3").delete()
+    new File(checkpointLocation + "/commits/.4.crc").delete()
+    new File(checkpointLocation + "/commits/4").delete()
+    new File(checkpointLocation + "/commits/.5.crc").delete()
+    new File(checkpointLocation + "/commits/5").delete()
+
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      AddData(inputData2, 3),
+      AddData(inputData2, 4),
+      AddData(inputData2, 5),
+      StartStream(checkpointLocation = checkpointLocation),
+      CheckNewAnswer(3, 4, 5),
+      StopStream
+    )
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2, 5))
+  }
+
+  test(
+    "recovery when gaps in offset and" +
+    " commit log"
+  ) {
+    val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds = inputData.toDS()
+
+    val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+
+    // process batches 0 to 2, then delete files so that the offset log only contains
+    // batches 0, 2 and the commit log only contains batch 0
+    testStream(ds)(
+      StartStream(checkpointLocation = checkpointLocation),
+      AddData(inputData, 0),
+      CheckNewAnswer(0),
+      AddData(inputData, 1),
+      CheckNewAnswer(1),
+      AddData(inputData, 2),
+      CheckNewAnswer(2),
+      StopStream
+    )
+
+    new File(checkpointLocation + "/offsets/1").delete()
+    new File(checkpointLocation + "/offsets/.1.crc").delete()
+    new File(checkpointLocation + "/commits/2").delete()
+    new File(checkpointLocation + "/commits/.2.crc").delete()
+    new File(checkpointLocation + "/commits/1").delete()
+    new File(checkpointLocation + "/commits/.1.crc").delete()
+
+    getListOfFiles(checkpointLocation + "/offsets")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0, 2))
+    getListOfFiles(checkpointLocation + "/commits")
+      .filter(file => !file.isHidden)
+      .map(file => file.getName.toInt)
+      .sorted should equal(Array(0))
+
+    /**
+     * start new stream
+     */
+    val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
+    val ds2 = inputData2.toDS()
+    testStream(ds2)(
+      // add back old data
+      AddData(inputData2, 0),
+      AddData(inputData2, 1),
+      AddData(inputData2, 2),
+      StartStream(checkpointLocation = checkpointLocation),
+      // should replay from batch 1

Review Comment:
   The offset log contains batches 0 and 2, and the commit log contains only batch 0. Thus, on a restart, offsets 1 through 2 (inclusive) will be processed.
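
   A minimal sketch of that restart arithmetic, as a simplified model rather than the actual recovery code. `replayRange` is a hypothetical helper for illustration only, not part of Spark. It assumes one input value per batch (as in these tests), so batch IDs double as data offsets, and it assumes the restart resumes right after the latest committed batch and runs through the latest batch in the offset log.

```scala
// Hypothetical helper, not Spark API: models which offsets are replayed on restart.
// Assumes one input value per batch, so batch IDs double as data offsets.
def replayRange(offsetLog: Seq[Int], commitLog: Seq[Int]): List[Int] = {
  if (offsetLog.isEmpty) {
    // Nothing was ever planned, so there is nothing to replay.
    Nil
  } else {
    // Resume right after the latest committed batch, or from 0 if nothing was committed.
    val start = if (commitLog.isEmpty) 0 else commitLog.max + 1
    // Run through the latest batch recorded in the offset log.
    (start to offsetLog.max).toList
  }
}
```

   Under those assumptions, `replayRange(Seq(0, 2), Seq(0))` gives `List(1, 2)` for this test, `replayRange(Seq(2), Seq.empty)` gives `List(0, 1, 2)` for the "first offset is not zero" test, and `replayRange(Seq(0, 2, 5), Seq(0, 2))` gives `List(3, 4, 5)` for the multiple-gaps test, which matches the `CheckNewAnswer` expectations in each.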



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

