Posted to commits@spark.apache.org by ka...@apache.org on 2022/03/13 05:10:08 UTC

[spark] branch branch-3.2 updated: [SPARK-38320][SS] Fix flatMapGroupsWithState timeout in batch with data for key

This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new d1b207a  [SPARK-38320][SS] Fix flatMapGroupsWithState timeout in batch with data for key
d1b207a is described below

commit d1b207a2cda7fa1b6b52871796b5fdad52f45406
Author: Alex Balikov <al...@databricks.com>
AuthorDate: Sun Mar 13 13:55:13 2022 +0900

    [SPARK-38320][SS] Fix flatMapGroupsWithState timeout in batch with data for key
    
    ### What changes were proposed in this pull request?
    
    As described in [SPARK-38320](https://issues.apache.org/jira/browse/SPARK-38320), the bug is that (flat)MapGroupsWithState can time out a key even though that key received data within the same batch. This violates the documented (flat)MapGroupsWithState contract. The root cause is that StateStore.iterator does not reflect StateStore changes made *after* its creation - this is illustrated in the test this PR adds to StateStoreSuite.scala.
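    
    For illustration only, a minimal hedged sketch of that contract using the public `GroupState` API (the `ContractSketch` object, the state function, the key/value types, and the 500 ms timeout are hypothetical and not part of this patch):
    
    ```scala
    import org.apache.spark.sql.streaming.GroupState
    
    object ContractSketch {
      // Hypothetical state function: if a key received data in the current batch,
      // Spark must not invoke it with state.hasTimedOut == true in that same batch.
      def stateFunc(key: Int, values: Iterator[Long], state: GroupState[Int]): Int = {
        if (values.nonEmpty) {
          assert(!state.hasTimedOut, "contract violation: timed out despite receiving data")
          state.update(values.size)
          // Re-arm an event-time timeout; it should be superseded whenever the key
          // receives data again in a later batch.
          state.setTimeoutTimestamp(500L)
        }
        state.getOption.getOrElse(0)
      }
    }
    ```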
    
    The fix is to *late-bind* the timeout-processing iterator (`timeoutProcessorIter`) in FlatMapGroupsWithStateExec so that it is created *after* the input iterator has been exhausted and the state changes have been applied to the StateStore.
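    
    The same pattern in isolation (a hedged sketch, not the actual FlatMapGroupsWithStateExec code; `LateBinding.lateBound` and `buildTimeoutIterator` are hypothetical stand-ins for the wrapper around `processor.processTimedOutState()`):
    
    ```scala
    object LateBinding {
      // Wrap the real iterator in a lazy val so it is only built on first access.
      // Because the downstream consumer exhausts the input-processing iterator before
      // touching this one, the underlying iterator is created only after all state
      // updates from input processing have been written to the state store.
      def lateBound[A](buildTimeoutIterator: () => Iterator[A]): Iterator[A] =
        new Iterator[A] {
          private lazy val underlying = buildTimeoutIterator()
          override def hasNext: Boolean = underlying.hasNext
          override def next(): A = underlying.next()
        }
    }
    ```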
    
    ### Why are the changes needed?
    
    The changes ensure that the state timeout processing iterator for (flat)MapGroupsWithState is created *after* the input is processed and the resulting changes are applied to the StateStore. Otherwise the iterator may not observe those changes, in particular the key's timeout timestamp being updated as part of input processing.
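    
    As a rough analogy (plain Scala, not the StateStore API), an iterator bound to a snapshot taken at creation time misses later writes:
    
    ```scala
    import scala.collection.immutable
    
    // Hypothetical store whose iterator captures the data as of creation time.
    class SnapshotStore {
      @volatile private var data = immutable.Map.empty[String, Long]
      def put(key: String, value: Long): Unit = { data = data.updated(key, value) }
      def iterator(): Iterator[(String, Long)] = data.iterator
    }
    
    object SnapshotDemo extends App {
      val store = new SnapshotStore
      val early = store.iterator()  // created before the write, like an eagerly-built timeout iterator
      store.put("key-1", 500L)      // e.g. the key's timeout timestamp updated during input processing
      val late = store.iterator()   // created after the write, like the late-bound iterator
      assert(early.toList.isEmpty)                 // misses the update
      assert(late.toList == List("key-1" -> 500L)) // sees the final state
    }
    ```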
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. Bug fix.
    
    ### How was this patch tested?
    
    * Added a test to StateStoreSuite.scala to illustrate the differences in state store iterator behavior across the state store implementations. In particular, the test shows that the RocksDB state store iterator does not reflect state store changes made after its creation.
    * Added a test to FlatMapGroupsWithStateSuite.scala which would fail with an unexpected state timeout if the issue were not fixed.
    
    Closes #35810 from alex-balikov/SPARK-38320-state-iterators2.
    
    Authored-by: Alex Balikov <al...@databricks.com>
    Signed-off-by: Jungtaek Lim <ka...@gmail.com>
    (cherry picked from commit 6b64e5dc74cbdc7e2b4ae42232e9610319ad73f3)
    Signed-off-by: Jungtaek Lim <ka...@gmail.com>
---
 .../streaming/FlatMapGroupsWithStateExec.scala     | 20 +++++---
 .../streaming/state/StateStoreSuite.scala          | 58 ++++++++++++++++++++++
 .../streaming/FlatMapGroupsWithStateSuite.scala    | 48 +++++++++++++++++-
 3 files changed, 118 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index a00a622..4c61975 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -167,12 +167,20 @@ case class FlatMapGroupsWithStateExec(
           timeoutProcessingStartTimeNs = System.nanoTime
         })
 
-    val timeoutProcessorIter =
-      CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), {
-        // Note: `timeoutLatencyMs` also includes the time the parent operator took for
-        // processing output returned through iterator.
-        timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs)
-      })
+    // SPARK-38320: Late-bind the timeout processing iterator so it is created *after* the input is
+    // processed (the input iterator is exhausted) and the state updates are written into the
+    // state store. Otherwise the iterator may not see the updates (e.g. with RocksDB state store).
+    val timeoutProcessorIter = new Iterator[InternalRow] {
+      private lazy val itr = getIterator()
+      override def hasNext = itr.hasNext
+      override def next() = itr.next()
+      private def getIterator(): Iterator[InternalRow] =
+        CompletionIterator[InternalRow, Iterator[InternalRow]](processor.processTimedOutState(), {
+          // Note: `timeoutLatencyMs` also includes the time the parent operator took for
+          // processing output returned through iterator.
+          timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs)
+        })
+    }
 
     // Generate an iterator that returns the rows grouped by the grouping function
     // Note that this code ensures that the filtering for timeout occurs only after
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 601b62b..dde925b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -1017,6 +1017,64 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     }
   }
 
+  // This test illustrates state store iterator behavior differences leading to SPARK-38320.
+  testWithAllCodec("SPARK-38320 - state store iterator behavior differences") {
+    val ROCKSDB_STATE_STORE = "RocksDBStateStore"
+    val dir = newDir()
+    val storeId = StateStoreId(dir, 0L, 1)
+    var version = 0L
+
+    tryWithProviderResource(newStoreProvider(storeId)) { provider =>
+      val store = provider.getStore(version)
+      logInfo(s"Running SPARK-38320 test with state store ${store.getClass.getName}")
+
+      val itr1 = store.iterator()  // itr1 is created before any writes to the store.
+      put(store, "1", 11, 100)
+      put(store, "2", 22, 200)
+      val itr2 = store.iterator()  // itr2 is created in the middle of the writes.
+      put(store, "1", 11, 101)  // Overwrite row (1, 11)
+      put(store, "3", 33, 300)
+      val itr3 = store.iterator()  // itr3 is created after all writes.
+
+      val expected = Set(("1", 11) -> 101, ("2", 22) -> 200, ("3", 33) -> 300)  // The final state.
+      // Itr1 does not see any updates - original state of the store (SPARK-38320)
+      assert(rowPairsToDataSet(itr1) === Set.empty[Set[((String, Int), Int)]])
+      assert(rowPairsToDataSet(itr2) === expected)
+      assert(rowPairsToDataSet(itr3) === expected)
+
+      version = store.commit()
+    }
+
+    // Reload the store from the committed version and repeat the above test.
+    tryWithProviderResource(newStoreProvider(storeId)) { provider =>
+      assert(version > 0)
+      val store = provider.getStore(version)
+
+      val itr1 = store.iterator()  // itr1 is created before any writes to the store.
+      put(store, "3", 33, 301)  // Overwrite row (3, 33)
+      put(store, "4", 44, 400)
+      val itr2 = store.iterator()  // itr2 is created in the middle of the writes.
+      put(store, "4", 44, 401)  // Overwrite row (4, 44)
+      put(store, "5", 55, 500)
+      val itr3 = store.iterator()  // itr3 is created after all writes.
+
+      // The final state.
+      val expected = Set(
+        ("1", 11) -> 101, ("2", 22) -> 200, ("3", 33) -> 301, ("4", 44) -> 401, ("5", 55) -> 500)
+      if (store.getClass.getName contains ROCKSDB_STATE_STORE) {
+        // RocksDB itr1 does not see any updates - original state of the store (SPARK-38320)
+        assert(rowPairsToDataSet(itr1) === Set(
+          ("1", 11) -> 101, ("2", 22) -> 200, ("3", 33) -> 300))
+      } else {
+        assert(rowPairsToDataSet(itr1) === expected)
+      }
+      assert(rowPairsToDataSet(itr2) === expected)
+      assert(rowPairsToDataSet(itr3) === expected)
+
+      version = store.commit()
+    }
+  }
+
   test("StateStore.get") {
     quietly {
       val dir = newDir()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index d34b2b8..291eb48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.streaming
 
 import java.io.File
-import java.sql.Date
+import java.sql.{Date, Timestamp}
 
 import org.apache.commons.io.FileUtils
 import org.scalatest.exceptions.TestFailedException
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution.RDDScanExec
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, MemoryStateStore, StateStore}
+import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, MemoryStateStore, RocksDBStateStoreProvider, StateStore}
 import org.apache.spark.sql.functions.timestamp_seconds
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.util.StreamManualClock
@@ -1519,6 +1519,50 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
     )
   }
 
+  test("SPARK-38320 - flatMapGroupsWithState state with data should not timeout") {
+    withTempDir { dir =>
+      withSQLConf(
+        (SQLConf.STREAMING_NO_DATA_MICRO_BATCHES_ENABLED.key -> "false"),
+        (SQLConf.CHECKPOINT_LOCATION.key -> dir.getCanonicalPath),
+        (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName)) {
+
+        val inputData = MemoryStream[Timestamp]
+        val stateFunc = (key: Int, values: Iterator[Timestamp], state: GroupState[Int]) => {
+          // Should never timeout. All batches should have data and even if a timeout is set,
+          // it should get cleared when the key receives data per contract.
+          require(!state.hasTimedOut, "The state should not have timed out!")
+          // Set state and timeout once, only on the first call. The timeout should get cleared
+          // in the subsequent batch which has data for the key.
+          if (!state.exists) {
+            state.update(0)
+            state.setTimeoutTimestamp(500)  // Timeout at 500 milliseconds.
+          }
+          0
+        }
+
+        val query = inputData.toDS()
+          .withWatermark("value", "0 seconds")
+          .groupByKey(_ => 0)  // Always the same key: 0.
+          .mapGroupsWithState(GroupStateTimeout.EventTimeTimeout())(stateFunc)
+          .writeStream
+          .format("console")
+          .outputMode("update")
+          .start()
+
+        try {
+          // 2 batches. Records are routed to the same key 0. The first batch sets timeout on
+          // the key, the second batch with data should clear the timeout.
+          (1 to 2).foreach { i =>
+            inputData.addData(new Timestamp(i * 1000))
+            query.processAllAvailable()
+          }
+        } finally {
+          query.stop()
+        }
+      }
+    }
+  }
+
   testWithAllStateVersions("mapGroupsWithState - initial state - null key") {
     val mapGroupsWithStateFunc =
         (key: String, values: Iterator[String], state: GroupState[RunningCount]) => {
