You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ka...@apache.org on 2021/11/05 07:42:58 UTC

[spark] 01/02: WIP: benchmark test code done

This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch WIP-optimize-eviction-in-rocksdb-state-store
in repository https://gitbox.apache.org/repos/asf/spark.git

commit 21d4f96c6c7a56acce30d485a0747e1d62d7ad27
Author: Jungtaek Lim <ka...@gmail.com>
AuthorDate: Thu Sep 30 11:53:39 2021 +0900

    WIP: benchmark test code done
---
 .../streaming/FlatMapGroupsWithStateExec.scala     |   7 +-
 .../state/HDFSBackedStateStoreProvider.scala       |  42 +-
 .../sql/execution/streaming/state/RocksDB.scala    | 158 ++++-
 .../streaming/state/RocksDBStateEncoder.scala      | 135 ++++-
 .../state/RocksDBStateStoreProvider.scala          | 100 +++-
 .../sql/execution/streaming/state/StateStore.scala |  40 +-
 .../execution/streaming/state/StateStoreRDD.scala  |   8 +-
 .../state/StreamingAggregationStateManager.scala   |  23 +
 .../state/SymmetricHashJoinStateManager.scala      |   3 +-
 .../sql/execution/streaming/state/package.scala    |  12 +-
 .../execution/streaming/statefulOperators.scala    |  62 +-
 .../sql/execution/streaming/streamingLimits.scala  |   5 +-
 .../execution/benchmark/StateStoreBenchmark.scala  | 633 +++++++++++++++++++++
 ...ngSortWithSessionWindowStateIteratorSuite.scala |   7 +-
 .../streaming/state/MemoryStateStore.scala         |  14 +
 .../state/RocksDBStateStoreIntegrationSuite.scala  |  60 +-
 .../streaming/state/RocksDBStateStoreSuite.scala   |   6 +-
 .../streaming/state/StateStoreRDDSuite.scala       |  18 +-
 .../streaming/state/StateStoreSuite.scala          |  31 +-
 .../StreamingSessionWindowStateManagerSuite.scala  |   4 +-
 .../apache/spark/sql/streaming/StreamSuite.scala   |   4 +-
 .../sql/streaming/StreamingAggregationSuite.scala  |  34 +-
 22 files changed, 1270 insertions(+), 136 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index a00a622..381aeb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -224,21 +224,23 @@ case class FlatMapGroupsWithStateExec(
           val stateStoreId = StateStoreId(
             stateInfo.get.checkpointLocation, stateInfo.get.operatorId, partitionId)
           val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId)
+          // TODO: check whether a prefix scan / event-time eviction context would help here.
           val store = StateStore.get(
             storeProviderId,
             groupingAttributes.toStructType,
             stateManager.stateSchema,
-            numColsPrefixKey = 0,
+            StatefulOperatorContext(),
             stateInfo.get.storeVersion, storeConf, hadoopConfBroadcast.value.value)
           val processor = new InputProcessor(store)
           processDataWithPartition(childDataIterator, store, processor, Some(initStateIterator))
       }
     } else {
+      // TODO: check whether a prefix scan / event-time eviction context would help here.
       child.execute().mapPartitionsWithStateStore[InternalRow](
         getStateInfo,
         groupingAttributes.toStructType,
         stateManager.stateSchema,
-        numColsPrefixKey = 0,
+        StatefulOperatorContext(),
         session.sqlContext.sessionState,
         Some(session.sqlContext.streams.stateStoreCoordinator)
       ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
@@ -334,6 +336,7 @@ case class FlatMapGroupsWithStateExec(
             throw new IllegalStateException(
               s"Cannot filter timed out keys for $timeoutConf")
         }
+        // TODO: check whether an event-time eviction would be cheaper than this full scan.
         val timingOutPairs = stateManager.getAllState(store).filter { state =>
           state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 75b7dae..96ba2a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -195,6 +195,22 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     override def toString(): String = {
       s"HDFSStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]"
     }
+
+    /** Lazily removes and returns the pairs matching `altPred` while the result is consumed. */
+    override def evictOnWatermark(
+        watermarkMs: Long,
+        altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+      // This store doesn't index the event time column, so `watermarkMs` is unused and every
+      // pair is checked against `altPred`. TODO: should the in-memory store do the same?
+      iterator().filter { pair =>
+        if (altPred(pair)) {
+          remove(pair.key)
+          true
+        } else {
+          false
+        }
+      }
+    }
   }
 
   def getMetricsForProvider(): Map[String, Long] = synchronized {
@@ -219,7 +235,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
 
   private def getLoadedMapForStore(version: Long): HDFSBackedStateStoreMap = synchronized {
     require(version >= 0, "Version cannot be less than 0")
-    val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+    val newMap = HDFSBackedStateStoreMap.create(keySchema, operatorContext.numColsPrefixKey)
     if (version > 0) {
       newMap.putAll(loadMap(version))
     }
@@ -230,7 +246,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): Unit = {
     this.stateStoreId_ = stateStoreId
@@ -240,10 +256,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     this.hadoopConf = hadoopConf
     this.numberOfVersionsToRetainInMemory = storeConf.maxVersionsToRetainInMemory
 
-    require((keySchema.length == 0 && numColsPrefixKey == 0) ||
-      (keySchema.length > numColsPrefixKey), "The number of columns in the key must be " +
-      "greater than the number of columns for prefix key!")
-    this.numColsPrefixKey = numColsPrefixKey
+    require((keySchema.length == 0 && operatorContext.numColsPrefixKey == 0) ||
+      (keySchema.length > operatorContext.numColsPrefixKey), "The number of columns in the key " +
+      "must be greater than the number of columns for prefix key!")
+
+    this.operatorContext = operatorContext
 
     fm.mkdirs(baseDir)
   }
@@ -283,7 +300,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
   @volatile private var storeConf: StateStoreConf = _
   @volatile private var hadoopConf: Configuration = _
   @volatile private var numberOfVersionsToRetainInMemory: Int = _
-  @volatile private var numColsPrefixKey: Int = 0
+  @volatile private var operatorContext: StatefulOperatorContext = _
 
   // TODO: The validation should be moved to a higher level so that it works for all state store
   // implementations
@@ -401,7 +418,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
 
         if (lastAvailableVersion <= 0) {
           // Use an empty map for versions 0 or less.
-          lastAvailableMap = Some(HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey))
+          lastAvailableMap = Some(HDFSBackedStateStoreMap.create(keySchema,
+            operatorContext.numColsPrefixKey))
         } else {
           lastAvailableMap =
             synchronized { Option(loadedMaps.get(lastAvailableVersion)) }
@@ -411,7 +429,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
 
       // Load all the deltas from the version after the last available one up to the target version.
       // The last available version is the one with a full snapshot, so it doesn't need deltas.
-      val resultMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+      val resultMap = HDFSBackedStateStoreMap.create(keySchema,
+        operatorContext.numColsPrefixKey)
       resultMap.putAll(lastAvailableMap.get)
       for (deltaVersion <- lastAvailableVersion + 1 to version) {
         updateFromDeltaFile(deltaVersion, resultMap)
@@ -554,7 +573,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
 
   private def readSnapshotFile(version: Long): Option[HDFSBackedStateStoreMap] = {
     val fileToRead = snapshotFile(version)
-    val map = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+    val map = HDFSBackedStateStoreMap.create(keySchema, operatorContext.numColsPrefixKey)
     var input: DataInputStream = null
 
     try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 1ff8b41..eed7827 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.streaming.state
 
 import java.io.File
+import java.util
 import java.util.Locale
 import javax.annotation.concurrent.GuardedBy
 
@@ -50,9 +51,12 @@ import org.apache.spark.util.{NextIterator, Utils}
  * @param hadoopConf   Hadoop configuration for talking to the remote file system
  * @param loggingId    Id that will be prepended in logs for isolating concurrent RocksDBs
+ * @param columnFamilies Column families to open; methods without a `cf` argument use "default"
  */
 class RocksDB(
     dfsRootDir: String,
     val conf: RocksDBConf,
+    // TODO: change "default" to constant
+    columnFamilies: Seq[String] = Seq("default"),
     localRootDir: File = Utils.createTempDir(),
     hadoopConf: Configuration = new Configuration,
     loggingId: String = "") extends Logging {
@@ -65,16 +69,10 @@ class RocksDB(
   private val flushOptions = new FlushOptions().setWaitForFlush(true)  // wait for flush to complete
   private val writeBatch = new WriteBatchWithIndex(true)  // overwrite multiple updates to a key
 
-  private val bloomFilter = new BloomFilter()
-  private val tableFormatConfig = new BlockBasedTableConfig()
-  tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024)
-  tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024))
-  tableFormatConfig.setFilterPolicy(bloomFilter)
-  tableFormatConfig.setFormatVersion(conf.formatVersion)
-
-  private val dbOptions = new Options() // options to open the RocksDB
+  private val dbOptions: DBOptions = new DBOptions() // options to open the RocksDB
   dbOptions.setCreateIfMissing(true)
-  dbOptions.setTableFormatConfig(tableFormatConfig)
+  dbOptions.setCreateMissingColumnFamilies(true)  // create the missing `columnFamilies`
+
   private val dbLogger = createLogger() // for forwarding RocksDB native logs to log4j
   dbOptions.setStatistics(new Statistics())
   private val nativeStats = dbOptions.statistics()
@@ -87,6 +85,8 @@ class RocksDB(
   private val acquireLock = new Object
 
   @volatile private var db: NativeRocksDB = _
+  @volatile private var columnFamilyHandles: util.Map[String, ColumnFamilyHandle] = _  // by name
+  @volatile private var defaultColumnFamilyHandle: ColumnFamilyHandle = _  // of "default"
   @volatile private var loadedVersion = -1L   // -1 = nothing valid is loaded
   @volatile private var numKeysOnLoadedVersion = 0L
   @volatile private var numKeysOnWritingVersion = 0L
@@ -96,7 +96,7 @@ class RocksDB(
   @volatile private var acquiredThreadInfo: AcquiredThreadInfo = _
 
   private val prefixScanReuseIter =
-    new java.util.concurrent.ConcurrentHashMap[Long, RocksIterator]()
+    new java.util.concurrent.ConcurrentHashMap[(Long, Int), RocksIterator]()  // (thread ID, CF ID)
 
   /**
    * Load the given version of data in a native RocksDB instance.
@@ -137,7 +137,30 @@ class RocksDB(
    * @note This will return the last written value even if it was uncommitted.
    */
   def get(key: Array[Byte]): Array[Byte] = {
-    writeBatch.getFromBatchAndDB(db, readOptions, key)
+    get(defaultColumnFamilyHandle, key)
+  }
+
+  /** Same as `get(key)`, but reads the key from the given column family. */
+  def get(cf: String, key: Array[Byte]): Array[Byte] = {
+    get(findColumnFamilyHandle(cf), key)
+  }
+
+  private def get(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = {
+    writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
+  }
+
+  // NOTE: merging needs a merge operator configured on the column family (openDB() sets none
+  // yet), and merged keys are not reflected in the number of keys.
+  def merge(key: Array[Byte], value: Array[Byte]): Unit = {
+    merge(defaultColumnFamilyHandle, key, value)
+  }
+
+  def merge(cf: String, key: Array[Byte], value: Array[Byte]): Unit = {
+    merge(findColumnFamilyHandle(cf), key, value)
+  }
+
+  private def merge(cfHandle: ColumnFamilyHandle, key: Array[Byte], value: Array[Byte]): Unit = {
+    writeBatch.merge(cfHandle, key, value)
   }
 
   /**
@@ -145,8 +166,22 @@ class RocksDB(
    * @note This update is not committed to disk until commit() is called.
    */
   def put(key: Array[Byte], value: Array[Byte]): Array[Byte] = {
-    val oldValue = writeBatch.getFromBatchAndDB(db, readOptions, key)
-    writeBatch.put(key, value)
+    put(defaultColumnFamilyHandle, key, value)
+  }
+
+  /** Same as `put(key, value)`, but writes the key to the given column family. */
+  def put(cf: String, key: Array[Byte], value: Array[Byte]): Array[Byte] = {
+    put(findColumnFamilyHandle(cf), key, value)
+  }
+
+  private def put(
+      cfHandle: ColumnFamilyHandle,
+      key: Array[Byte],
+      value: Array[Byte]): Array[Byte] = {
+    val oldValue = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
+    writeBatch.put(cfHandle, key, value)
-    if (oldValue == null) {
+    // Only the default column family holds state rows; entries of the other column families
+    // (e.g. indexes) must not be counted as keys.
+    if (oldValue == null && (cfHandle eq defaultColumnFamilyHandle)) {
       numKeysOnWritingVersion += 1
     }
@@ -158,9 +193,21 @@ class RocksDB(
    * @note This update is not committed to disk until commit() is called.
    */
   def remove(key: Array[Byte]): Array[Byte] = {
-    val value = writeBatch.getFromBatchAndDB(db, readOptions, key)
+    remove(defaultColumnFamilyHandle, key)
+  }
+
+  /** Same as `remove(key)`, but removes the key from the given column family. */
+  def remove(cf: String, key: Array[Byte]): Array[Byte] = {
+    remove(findColumnFamilyHandle(cf), key)
+  }
+
+  private def remove(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = {
+    val value = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
     if (value != null) {
-      writeBatch.remove(key)
-      numKeysOnWritingVersion -= 1
+      writeBatch.delete(cfHandle, key)
+      // Mirrors put(): only keys of the default column family are counted.
+      if (cfHandle eq defaultColumnFamilyHandle) {
+        numKeysOnWritingVersion -= 1
+      }
     }
     value
@@ -169,8 +211,17 @@ class RocksDB(
   /**
    * Get an iterator of all committed and uncommitted key-value pairs.
    */
-  def iterator(): Iterator[ByteArrayPair] = {
-    val iter = writeBatch.newIteratorWithBase(db.newIterator())
+  def iterator(): NextIterator[ByteArrayPair] = {
+    iterator(defaultColumnFamilyHandle)
+  }
+
+  /** Same as `iterator()`, but iterates over the given column family. */
+  def iterator(cf: String): NextIterator[ByteArrayPair] = {
+    iterator(findColumnFamilyHandle(cf))
+  }
+
+  private def iterator(cfHandle: ColumnFamilyHandle): NextIterator[ByteArrayPair] = {
+    val iter = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle))
     logInfo(s"Getting iterator from version $loadedVersion")
     iter.seekToFirst()
 
@@ -197,11 +248,22 @@ class RocksDB(
   }
 
   def prefixScan(prefix: Array[Byte]): Iterator[ByteArrayPair] = {
+    prefixScan(defaultColumnFamilyHandle, prefix)
+  }
+
+  /** Same as `prefixScan(prefix)`, but scans the given column family. */
+  def prefixScan(cf: String, prefix: Array[Byte]): Iterator[ByteArrayPair] = {
+    prefixScan(findColumnFamilyHandle(cf), prefix)
+  }
+
+  private def prefixScan(
+      cfHandle: ColumnFamilyHandle, prefix: Array[Byte]): Iterator[ByteArrayPair] = {
     val threadId = Thread.currentThread().getId
-    val iter = prefixScanReuseIter.computeIfAbsent(threadId, tid => {
-      val it = writeBatch.newIteratorWithBase(db.newIterator())
+    // One reusable iterator per (thread ID, column family ID) pair.
+    val iter = prefixScanReuseIter.computeIfAbsent((threadId, cfHandle.getID), ids => {
+      val it = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle))
       logInfo(s"Getting iterator from version $loadedVersion for prefix scan on " +
-        s"thread ID $tid")
+        s"thread ID ${ids._1} and column family ID ${ids._2}")
       it
     })
 
@@ -223,6 +285,13 @@ class RocksDB(
     }
   }
 
+  /** Returns the handle of the given opened column family, failing if there is none. */
+  private def findColumnFamilyHandle(cf: String): ColumnFamilyHandle = {
+    Option(columnFamilyHandles.get(cf)).getOrElse {
+      throw new IllegalArgumentException(s"Handle for column family $cf is not found")
+    }
+  }
+
   /**
    * Commit all the updates made as a version to DFS. The steps it needs to do to commits are:
    * - Write all the updates to the native RocksDB
@@ -242,11 +310,16 @@ class RocksDB(
       val writeTimeMs = timeTakenMs { db.write(writeOptions, writeBatch) }
 
       logInfo(s"Flushing updates for $newVersion")
-      val flushTimeMs = timeTakenMs { db.flush(flushOptions) }
+      val flushTimeMs = timeTakenMs {
+        db.flush(flushOptions,
+          new util.ArrayList[ColumnFamilyHandle](columnFamilyHandles.values()))
+      }
 
       val compactTimeMs = if (conf.compactOnCommit) {
         logInfo("Compacting")
-        timeTakenMs { db.compactRange() }
+        timeTakenMs {
+          columnFamilyHandles.values().forEach(cfHandle => db.compactRange(cfHandle))
+        }
       } else 0
 
       logInfo("Pausing background work")
@@ -422,12 +496,59 @@ class RocksDB(
 
+  // Native objects (column family options, block cache and bloom filter) created by openDB()
+  // for the opened DB instance. closeDB() releases them, so reopening the DB doesn't leak them.
+  @volatile private var columnFamilyNativeResources: Seq[AutoCloseable] = Nil
+
   private def openDB(): Unit = {
     assert(db == null)
-    db = NativeRocksDB.open(dbOptions, workingDir.toString)
+
+    // A single block cache and bloom filter are shared by all the column families, so that the
+    // block cache capacity is blockCacheSizeMB regardless of the number of column families.
+    val bloomFilter = new BloomFilter()
+    val blockCache = new LRUCache(conf.blockCacheSizeMB * 1024 * 1024)
+    val tableFormatConfig = new BlockBasedTableConfig()
+    tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024)
+    tableFormatConfig.setBlockCache(blockCache)
+    tableFormatConfig.setFilterPolicy(bloomFilter)
+    tableFormatConfig.setFormatVersion(conf.formatVersion)
+
+    val columnFamilyOptions = new ColumnFamilyOptions()
+    columnFamilyOptions.setTableFormatConfig(tableFormatConfig)
+    columnFamilyNativeResources = Seq(columnFamilyOptions, blockCache, bloomFilter)
+
+    // Encode the names explicitly as UTF-8 rather than with the platform default charset.
+    val columnFamilyDescriptors = new util.ArrayList[ColumnFamilyDescriptor]()
+    columnFamilies.foreach { cf =>
+      columnFamilyDescriptors.add(new ColumnFamilyDescriptor(
+        cf.getBytes(java.nio.charset.StandardCharsets.UTF_8), columnFamilyOptions))
+    }
+
+    val cfHandles = new util.ArrayList[ColumnFamilyHandle](columnFamilyDescriptors.size())
+    db = NativeRocksDB.open(dbOptions, workingDir.toString, columnFamilyDescriptors,
+      cfHandles)
+
+    // Relies on the handles being returned in the same order as the descriptors.
+    columnFamilyHandles = new util.HashMap[String, ColumnFamilyHandle]()
+    columnFamilies.indices.foreach { idx =>
+      columnFamilyHandles.put(columnFamilies(idx), cfHandles.get(idx))
+    }
+
+    // FIXME: constant
+    defaultColumnFamilyHandle = columnFamilyHandles.get("default")
+
     logInfo(s"Opened DB with conf ${conf}")
   }
 
   private def closeDB(): Unit = {
     if (db != null) {
+      columnFamilyHandles.entrySet().forEach(pair => db.destroyColumnFamilyHandle(pair.getValue))
+      columnFamilyHandles.clear()
+      columnFamilyHandles = null
+      defaultColumnFamilyHandle = null
+
       db.close()
       db = null
+
+      // Release the native objects backing the options only after the DB has been closed.
+      columnFamilyNativeResources.foreach(_.close())
+      columnFamilyNativeResources = Nil
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index 81755e5..323826d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.lang.{Long => JLong}
+import java.nio.ByteOrder
+
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION}
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{StructField, StructType, TimestampType}
 import org.apache.spark.unsafe.Platform
 
 sealed trait RocksDBStateEncoder {
@@ -27,6 +30,15 @@ sealed trait RocksDBStateEncoder {
   def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
   def extractPrefixKey(key: UnsafeRow): UnsafeRow
 
+  /** Whether the encoder supports an index of the keys ordered by event time. */
+  def supportEventTimeIndex: Boolean
+  /** Extracts the event time of the given key, in milliseconds. */
+  def extractEventTime(key: UnsafeRow): Long
+  /** Builds the event time index key: the sortable event time, then the encoded key. */
+  def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte]
+  /** Splits an event time index key back into the event time and the encoded key. */
+  def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte])
+
   def encodeKey(row: UnsafeRow): Array[Byte]
   def encodeValue(row: UnsafeRow): Array[Byte]
 
@@ -39,11 +51,14 @@ object RocksDBStateEncoder {
   def getEncoder(
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int): RocksDBStateEncoder = {
+      numColsPrefixKey: Int,
+      eventTimeColIdx: Array[Int] = Array.empty): RocksDBStateEncoder = {
     if (numColsPrefixKey > 0) {
+      // TODO: support the event time index along with prefix key scan. `eventTimeColIdx` is
+      // ignored here for now, as PrefixKeyScanStateEncoder doesn't support the index.
       new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
     } else {
-      new NoPrefixKeyStateEncoder(keySchema, valueSchema)
+      new NoPrefixKeyStateEncoder(keySchema, valueSchema, eventTimeColIdx)
     }
   }
 
@@ -86,6 +96,32 @@ object RocksDBStateEncoder {
   }
 }
 
+/**
+ * Helpers to encode a long so that comparing the in-memory bytes of the encoded value in
+ * sequential order (as unsigned bytes) matches comparing the original values numerically.
+ */
+object BinarySortable {
+  private val IS_LITTLE_ENDIAN: Boolean = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN
+  private val SIGN_BIT_LONG: Long = JLong.MIN_VALUE
+
+  // Platform.putLong/getLong use the native byte order: swapping the bytes on little endian
+  // platforms makes the memory layout big endian (most significant byte first) everywhere.
+  // Swapping is its own inverse, so the same helper serves both encoding and decoding.
+  private def swapToBigEndianLayout(value: Long): Long =
+    if (IS_LITTLE_ENDIAN) JLong.reverseBytes(value) else value
+
+  /**
+   * Flips the sign bit, so that negative values are ordered before non-negative ones, and lays
+   * the bytes out in big endian order.
+   */
+  def encodeToBinarySortableLong(value: Long): Long =
+    swapToBigEndianLayout(value ^ SIGN_BIT_LONG)
+
+  /** Reverts [[encodeToBinarySortableLong]]: restores the native layout, then the sign bit. */
+  def decodeBinarySortableLong(encoded: Long): Long =
+    swapToBigEndianLayout(encoded) ^ SIGN_BIT_LONG
+}
+
 class PrefixKeyScanStateEncoder(
     keySchema: StructType,
     valueSchema: StructType,
@@ -185,6 +228,22 @@ class PrefixKeyScanStateEncoder(
   }
 
   override def supportPrefixKeyScan: Boolean = true
+
+  // TODO: support the event time index along with prefix key scan. Until then the methods
+  // below fail; they are only expected to be called when `supportEventTimeIndex` is true.
+  override def supportEventTimeIndex: Boolean = false
+
+  override def extractEventTime(key: UnsafeRow): Long = {
+    throw new IllegalStateException("This encoder doesn't support event time index!")
+  }
+
+  override def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = {
+    throw new IllegalStateException("This encoder doesn't support event time index!")
+  }
+
+  override def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = {
+    throw new IllegalStateException("This encoder doesn't support event time index!")
+  }
 }
 
 /**
@@ -197,8 +257,13 @@ class PrefixKeyScanStateEncoder(
  *    (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes,
  *    then the generated array byte will be N+1 bytes.
+ *
+ * If `eventTimeColIdx` is non-empty, it is the path of field ordinals (nested through struct
+ * columns) to the timestamp column of the key that is used for the event time index.
  */
-class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
-  extends RocksDBStateEncoder {
+class NoPrefixKeyStateEncoder(
+    keySchema: StructType,
+    valueSchema: StructType,
+    eventTimeColIdx: Array[Int] = Array.empty) extends RocksDBStateEncoder {
 
   import RocksDBStateEncoder._
 
@@ -207,6 +269,41 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
   private val valueRow = new UnsafeRow(valueSchema.size)
   private val rowTuple = new UnsafeRowPair()
 
+  validateColumnTypeOnEventTimeColumn()
+
+  /**
+   * Fails with IllegalStateException unless `eventTimeColIdx` is empty, or is a path of valid
+   * field ordinals in which every ordinal but the last points to a struct column and the last
+   * one points to a timestamp column.
+   */
+  private def validateColumnTypeOnEventTimeColumn(): Unit = {
+    def fail(): Nothing = {
+      // FIXME: better error message
+      throw new IllegalStateException("event time column is not properly specified! " +
+        s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
+    }
+
+    // Out of range ordinals are reported via fail() instead of an index out of bounds error.
+    def fieldAt(schema: StructType, idx: Int): StructField = {
+      if (idx < 0 || idx >= schema.length) fail()
+      schema(idx)
+    }
+
+    if (eventTimeColIdx.nonEmpty) {
+      val parentSchema = eventTimeColIdx.dropRight(1).foldLeft(keySchema) { (schema, idx) =>
+        fieldAt(schema, idx).dataType match {
+          case structType: StructType => structType
+          case _ => fail()
+        }
+      }
+
+      fieldAt(parentSchema, eventTimeColIdx.last).dataType match {
+        case _: TimestampType =>
+        case _ => fail()
+      }
+    }
+  }
+
   override def encodeKey(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
 
   override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
@@ -249,4 +337,53 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
   override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
     throw new IllegalStateException("This encoder doesn't support prefix key!")
   }
+
+  override def supportEventTimeIndex: Boolean = eventTimeColIdx.nonEmpty
+
+  /**
+   * Returns the event time of the key in milliseconds. The timestamp column stores microseconds,
+   * which are rounded up rather than truncated: `eventTimeMs <= watermarkMs` then holds exactly
+   * when `eventTimeMicros <= watermarkMs * 1000`, whereas truncating would make a key up to
+   * 999 microseconds past the watermark look expired.
+   */
+  override def extractEventTime(key: UnsafeRow): Long = {
+    var curRow: UnsafeRow = key
+    var curSchema: StructType = keySchema
+
+    eventTimeColIdx.dropRight(1).foreach { idx =>
+      // validation is done in initialization phase
+      curSchema = curSchema(idx).dataType.asInstanceOf[StructType]
+      curRow = curRow.getStruct(idx, curSchema.length)
+    }
+
+    val micros = curRow.getLong(eventTimeColIdx.last)
+    val millis = Math.floorDiv(micros, 1000L)
+    if (Math.floorMod(micros, 1000L) == 0L) millis else millis + 1
+  }
+
+  override def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = {
+    val newKey = new Array[Byte](JLong.BYTES + encodedKey.length)
+
+    // The binary sortable event time goes first, so the index keys are ordered by event time.
+    Platform.putLong(newKey, Platform.BYTE_ARRAY_OFFSET,
+      BinarySortable.encodeToBinarySortableLong(timestamp))
+    Platform.copyMemory(encodedKey, Platform.BYTE_ARRAY_OFFSET, newKey,
+      Platform.BYTE_ARRAY_OFFSET + JLong.BYTES, encodedKey.length)
+
+    newKey
+  }
+
+  override def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = {
+    require(eventTimeBytes.length >= JLong.BYTES,
+      s"Event time index key is too short: ${eventTimeBytes.length} bytes")
+    val encoded = Platform.getLong(eventTimeBytes, Platform.BYTE_ARRAY_OFFSET)
+    val timestamp = BinarySortable.decodeBinarySortableLong(encoded)
+
+    val keyLength = eventTimeBytes.length - JLong.BYTES
+    val key = new Array[Byte](keyLength)
+    Platform.copyMemory(eventTimeBytes, Platform.BYTE_ARRAY_OFFSET + JLong.BYTES,
+      key, Platform.BYTE_ARRAY_OFFSET, keyLength)
+
+    (timestamp, key)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index a2b33c2..1d66220 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -25,7 +25,8 @@ import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.util.{NextIterator, Utils}
 
 private[sql] class RocksDBStateStoreProvider
   extends StateStoreProvider with Logging with Closeable {
@@ -60,12 +61,31 @@ private[sql] class RocksDBStateStoreProvider
       verify(state == UPDATING, "Cannot put after already committed or aborted")
       verify(key != null, "Key cannot be null")
       require(value != null, "Cannot put a null value")
-      rocksDB.put(encoder.encodeKey(key), encoder.encodeValue(value))
+
+      val encodedKey = encoder.encodeKey(key)
+      val encodedValue = encoder.encodeValue(value)
+      rocksDB.put(encodedKey, encodedValue)
+
+      // Index the key by its event time; the index entry carries an empty value.
+      if (encoder.supportEventTimeIndex) {
+        val timestamp = encoder.extractEventTime(key)
+        val tsKey = encoder.encodeEventTimeIndexKey(timestamp, encodedKey)
+
+        rocksDB.put(CF_EVENT_TIME_INDEX, tsKey, Array.empty)
+      }
     }
 
     override def remove(key: UnsafeRow): Unit = {
       verify(state == UPDATING, "Cannot remove after already committed or aborted")
       verify(key != null, "Key cannot be null")
-      rocksDB.remove(encoder.encodeKey(key))
+      val encodedKey = encoder.encodeKey(key)
+      rocksDB.remove(encodedKey)
+
+      // Drop the matching event time index entry too: a dangling entry would make the eviction
+      // fail later, as it expects every indexed key to be present.
+      if (encoder.supportEventTimeIndex) {
+        val timestamp = encoder.extractEventTime(key)
+        rocksDB.remove(CF_EVENT_TIME_INDEX, encoder.encodeEventTimeIndexKey(timestamp, encodedKey))
+      }
     }
 
@@ -161,13 +173,75 @@ private[sql] class RocksDBStateStoreProvider
 
     /** Return the [[RocksDB]] instance in this store. This is exposed mainly for testing. */
     def dbInstance(): RocksDB = rocksDB
+
+    /**
+     * Removes and returns the rows whose event time is at or before `watermarkMs`.
+     *
+     * If the encoder supports an event time index, the index column family is scanned from the
+     * start and the scan stops at the first entry whose event time exceeds the watermark. This
+     * relies on index keys starting with a binary-sortable event time. `altPred` is not
+     * consulted on this path. Otherwise every row is scanned, and the rows matching `altPred`
+     * are removed.
+     *
+     * Rows are removed as the returned iterator is consumed, so callers must drain it.
+     *
+     * NOTE(review): the index path compares the decoded index timestamp with `watermarkMs`
+     * directly. That assumes the index stores milliseconds; confirm against extractEventTime.
+     *
+     * @param watermarkMs the event time watermark to evict against
+     * @param altPred predicate selecting rows to evict when no event time index is available
+     * @return iterator over the evicted key/value pairs
+     * @throws IllegalStateException while iterating, if an index entry refers to a key that is
+     *                               missing from the store
+     */
+    override def evictOnWatermark(
+        watermarkMs: Long,
+        altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+
+      if (encoder.supportEventTimeIndex) {
+        val kv = new ByteArrayPair()
+
+        new NextIterator[UnsafeRowPair] {
+          private val iter = rocksDB.iterator(CF_EVENT_TIME_INDEX)
+
+          override protected def getNext(): UnsafeRowPair = {
+            if (!iter.hasNext) {
+              finished = true
+              null
+            } else {
+              val indexKey = iter.next().key
+              val encodedTs = Platform.getLong(indexKey, Platform.BYTE_ARRAY_OFFSET)
+              if (BinarySortable.decodeBinarySortableLong(encodedTs) > watermarkMs) {
+                // Index keys sort by event time, so no later entry can be evictable either.
+                finished = true
+                null
+              } else {
+                evict(indexKey)
+              }
+            }
+          }
+
+          /** Drops one index entry plus the row it points to, returning the decoded row. */
+          private def evict(indexKey: Array[Byte]): UnsafeRowPair = {
+            // FIXME: can we leverage deleteRange to bulk delete on index?
+            rocksDB.remove(CF_EVENT_TIME_INDEX, indexKey)
+            val (_, encodedKey) = encoder.decodeEventTimeIndexKey(indexKey)
+            val value = rocksDB.get(encodedKey)
+            if (value == null) {
+              throw new IllegalStateException("Event time index has been broken!")
+            }
+            kv.set(encodedKey, value)
+            val rowPair = encoder.decode(kv)
+            rocksDB.remove(encodedKey)
+            rowPair
+          }
+
+          override protected def close(): Unit = {
+            iter.closeIfNeeded()
+          }
+        }
+      } else {
+        rocksDB.iterator().flatMap { kv =>
+          val rowPair = encoder.decode(kv)
+          if (altPred(rowPair)) {
+            rocksDB.remove(kv.key)
+            Some(rowPair)
+          } else {
+            None
+          }
+        }
+      }
+    }
   }
 
   override def init(
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): Unit = {
     this.stateStoreId_ = stateStoreId
@@ -176,11 +250,14 @@ private[sql] class RocksDBStateStoreProvider
     this.storeConf = storeConf
     this.hadoopConf = hadoopConf
 
-    require((keySchema.length == 0 && numColsPrefixKey == 0) ||
-      (keySchema.length > numColsPrefixKey), "The number of columns in the key must be " +
-      "greater than the number of columns for prefix key!")
+    require((keySchema.length == 0 && operatorContext.numColsPrefixKey == 0) ||
+      (keySchema.length > operatorContext.numColsPrefixKey), "The number of columns in the key " +
+      "must be greater than the number of columns for prefix key!")
 
-    this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema, numColsPrefixKey)
+    this.operatorContext = operatorContext
+
+    this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema,
+      operatorContext.numColsPrefixKey, operatorContext.eventTimeColIdx)
 
     rocksDB // lazy initialization
   }
@@ -212,6 +289,7 @@ private[sql] class RocksDBStateStoreProvider
   @volatile private var valueSchema: StructType = _
   @volatile private var storeConf: StateStoreConf = _
   @volatile private var hadoopConf: Configuration = _
+  // Operator-level settings; assigned from the operatorContext argument of init().
+  @volatile private var operatorContext: StatefulOperatorContext = _
 
   private[sql] lazy val rocksDB = {
     val dfsRootDir = stateStoreId.storeCheckpointLocation().toString
@@ -219,7 +297,10 @@ private[sql] class RocksDBStateStoreProvider
       s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})"
     val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
     val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr)
-    new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr)
+    new RocksDB(dfsRootDir, RocksDBConf(storeConf),
+      columnFamilies = Seq("default", RocksDBStateStoreProvider.CF_EVENT_TIME_INDEX),
+      localRootDir = localRootDir,
+      hadoopConf = hadoopConf, loggingId = storeIdStr)
   }
 
   @volatile private var encoder: RocksDBStateEncoder = _
@@ -234,6 +315,9 @@ object RocksDBStateStoreProvider {
   val STATE_ENCODING_NUM_VERSION_BYTES = 1
   val STATE_ENCODING_VERSION: Byte = 0
 
+  // Reserved column families. CF_EVENT_TIME_INDEX holds the event time index: put() writes an
+  // entry keyed by encodeEventTimeIndexKey(eventTime, encodedKey) with an empty value.
+  val CF_EVENT_TIME_INDEX: String = "__event_time_idx"
+
   // Native operation latencies report as latency in microseconds
   // as SQLMetrics support millis. Convert the value to millis
   val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 5020638..cee9ad1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -130,6 +130,11 @@ trait StateStore extends ReadStateStore {
    */
   override def iterator(): Iterator[UnsafeRowPair]
 
+  /**
+   * Removes the rows that are evictable for the given watermark and returns them.
+   *
+   * Implementations may locate evictable rows through an event time index instead of scanning
+   * the whole store. Implementations without such an index should evict the rows for which
+   * `altPred` holds. Removal may happen lazily while the returned iterator is consumed, so
+   * callers must drain it for the eviction to complete.
+   *
+   * @param watermarkMs the current event time watermark
+   * @param altPred predicate selecting rows to evict when the watermark alone cannot be used
+   * @return iterator over the evicted key/value pairs
+   */
+  def evictOnWatermark(
+      watermarkMs: Long,
+      altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair]
+
   /** Current metrics of the state store */
   def metrics: StateStoreMetrics
 
@@ -229,6 +234,19 @@ class InvalidUnsafeRowException
     "checkpoint or use the legacy Spark version to process the streaming state.", null)
 
 /**
+ * Operator-level settings handed to [[StateStoreProvider.init]], letting the provider adapt
+ * how it stores the operator's state.
+ *
+ * @param numColsPrefixKey The number of leftmost columns to be used as prefix key.
+ *                         A value not greater than 0 means the operator doesn't activate prefix
+ *                         key, and the operator should not call prefixScan method in StateStore.
+ * @param eventTimeColIdx  Ordinal path to the event time value in the key row. The first
+ *                         element is the ordinal of the key column. An optional second element
+ *                         is the ordinal of the field inside that column when it is a struct
+ *                         (e.g. a window). Empty (the default) means no event time column is
+ *                         specified. Only event time columns in the key are supported.
+ */
+case class StatefulOperatorContext(
+    numColsPrefixKey: Int = 0,
+    eventTimeColIdx: Array[Int] = Array.empty) {
+
+  // Arrays use reference equality, so the generated equals/hashCode/toString would treat two
+  // contexts with the same event time path as different and print an opaque array reference.
+  override def equals(other: Any): Boolean = other match {
+    case that: StatefulOperatorContext =>
+      numColsPrefixKey == that.numColsPrefixKey &&
+        eventTimeColIdx.sameElements(that.eventTimeColIdx)
+    case _ => false
+  }
+
+  override def hashCode(): Int = (numColsPrefixKey, eventTimeColIdx.toSeq).hashCode()
+
+  override def toString: String =
+    s"StatefulOperatorContext($numColsPrefixKey,${eventTimeColIdx.mkString("[", ",", "]")})"
+}
+
+/**
  * Trait representing a provider that provide [[StateStore]] instances representing
  * versions of state data.
  *
@@ -255,9 +273,7 @@ trait StateStoreProvider {
    * @param stateStoreId Id of the versioned StateStores that this provider will generate
    * @param keySchema Schema of keys to be stored
    * @param valueSchema Schema of value to be stored
-   * @param numColsPrefixKey The number of leftmost columns to be used as prefix key.
-   *                         A value not greater than 0 means the operator doesn't activate prefix
-   *                         key, and the operator should not call prefixScan method in StateStore.
+   * @param operatorContext Operator-specific settings, e.g. prefix key and event time columns
    * @param storeConfs Configurations used by the StateStores
    * @param hadoopConf Hadoop configuration that could be used by StateStore to save state data
    */
@@ -265,7 +281,7 @@ trait StateStoreProvider {
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       storeConfs: StateStoreConf,
       hadoopConf: Configuration): Unit
 
@@ -318,11 +334,11 @@ object StateStoreProvider {
       providerId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStoreProvider = {
     val provider = create(storeConf.providerClass)
-    provider.init(providerId.storeId, keySchema, valueSchema, numColsPrefixKey,
+    provider.init(providerId.storeId, keySchema, valueSchema, operatorContext,
       storeConf, hadoopConf)
     provider
   }
@@ -471,13 +487,13 @@ object StateStore extends Logging {
       storeProviderId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       version: Long,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): ReadStateStore = {
     require(version >= 0)
     val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
-      numColsPrefixKey, storeConf, hadoopConf)
+      operatorContext, storeConf, hadoopConf)
     storeProvider.getReadStore(version)
   }
 
@@ -486,13 +502,13 @@ object StateStore extends Logging {
       storeProviderId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       version: Long,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStore = {
     require(version >= 0)
     val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
-      numColsPrefixKey, storeConf, hadoopConf)
+      operatorContext, storeConf, hadoopConf)
     storeProvider.getStore(version)
   }
 
@@ -500,7 +516,7 @@ object StateStore extends Logging {
       storeProviderId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStoreProvider = {
     loadedProviders.synchronized {
@@ -527,7 +543,7 @@ object StateStore extends Logging {
       val provider = loadedProviders.getOrElseUpdate(
         storeProviderId,
         StateStoreProvider.createAndInit(
-          storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeConf, hadoopConf)
+          storeProviderId, keySchema, valueSchema, operatorContext, storeConf, hadoopConf)
       )
       val otherProviderIds = loadedProviders.keys.filter(_ != storeProviderId).toSeq
       val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, otherProviderIds)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index fbe83ad..33d83b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -74,7 +74,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
     storeVersion: Long,
     keySchema: StructType,
     valueSchema: StructType,
-    numColsPrefixKey: Int,
+    operatorContext: StatefulOperatorContext,
     sessionState: SessionState,
     @transient private val storeCoordinator: Option[StateStoreCoordinatorRef],
     extraOptions: Map[String, String] = Map.empty)
@@ -87,7 +87,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
     val storeProviderId = getStateProviderId(partition)
 
     val store = StateStore.getReadOnly(
-      storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeVersion,
+      storeProviderId, keySchema, valueSchema, operatorContext, storeVersion,
       storeConf, hadoopConfBroadcast.value.value)
     val inputIter = dataRDD.iterator(partition, ctxt)
     storeReadFunction(store, inputIter)
@@ -108,7 +108,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     storeVersion: Long,
     keySchema: StructType,
     valueSchema: StructType,
-    numColsPrefixKey: Int,
+    operatorContext: StatefulOperatorContext,
     sessionState: SessionState,
     @transient private val storeCoordinator: Option[StateStoreCoordinatorRef],
     extraOptions: Map[String, String] = Map.empty)
@@ -121,7 +121,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     val storeProviderId = getStateProviderId(partition)
 
     val store = StateStore.get(
-      storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeVersion,
+      storeProviderId, keySchema, valueSchema, operatorContext, storeVersion,
       storeConf, hadoopConfBroadcast.value.value)
     val inputIter = dataRDD.iterator(partition, ctxt)
     storeUpdateFunction(store, inputIter)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
index 36138f1..c6b63cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
@@ -51,6 +51,12 @@ sealed trait StreamingAggregationStateManager extends Serializable {
   /** Remove a single non-null key from the target state store. */
   def remove(store: StateStore, key: UnsafeRow): Unit
 
+  /**
+   * Removes the rows of `store` that are evictable for the given watermark and returns them,
+   * with values in the same form this manager's read methods return.
+   *
+   * @param store the state store to evict from
+   * @param watermarkMs the event time watermark, forwarded to [[StateStore.evictOnWatermark]]
+   * @param altPred predicate on the stored pair, used by stores that cannot evict based on the
+   *                watermark alone
+   * @return iterator over the evicted pairs; rows may only be removed as it is consumed
+   */
+  def evictOnWatermark(
+      store: StateStore,
+      watermarkMs: Long,
+      altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair]
+
   /** Return an iterator containing all the key-value pairs in target state store. */
   def iterator(store: ReadStateStore): Iterator[UnsafeRowPair]
 
@@ -128,6 +134,13 @@ class StreamingAggregationStateManagerImplV1(
   override def values(store: ReadStateStore): Iterator[UnsafeRow] = {
     store.iterator().map(_.value)
   }
+
+  /** Evicts via [[StateStore.evictOnWatermark]], returning the stored pairs unchanged. */
+  override def evictOnWatermark(
+      store: StateStore,
+      watermarkMs: Long,
+      altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] =
+    store.evictOnWatermark(watermarkMs, altPred)
 }
 
 /**
@@ -186,6 +199,16 @@ class StreamingAggregationStateManagerImplV2(
     store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair)))
   }
 
+  /**
+   * Evicts via [[StateStore.evictOnWatermark]] and restores every evicted value to the original
+   * input row with `restoreOriginalRow`, as `values` does. `altPred` is evaluated by the store
+   * against the stored pair, before the value is restored.
+   */
+  override def evictOnWatermark(
+      store: StateStore,
+      watermarkMs: Long,
+      altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+    store.evictOnWatermark(watermarkMs, altPred).map { rowPair =>
+      new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair))
+    }
+  }
+
   override def values(store: ReadStateStore): Iterator[UnsafeRow] = {
     store.iterator().map(rowPair => restoreOriginalRow(rowPair))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
index f301d23..ce56845 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -363,8 +363,9 @@ class SymmetricHashJoinStateManager(
     protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = {
       val storeProviderId = StateStoreProviderId(
         stateInfo.get, partitionId, getStateStoreName(joinSide, stateStoreType))
+      // FIXME: would setting prefixScan / evict help?
       val store = StateStore.get(
-        storeProviderId, keySchema, valueSchema, numColsPrefixKey = 0,
+        storeProviderId, keySchema, valueSchema, StatefulOperatorContext(),
         stateInfo.get.storeVersion, storeConf, hadoopConf)
       logInfo(s"Loaded store ${store.id}")
       store
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 01ff72b..7cd25eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -35,14 +35,14 @@ package object state {
         stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,
-        numColsPrefixKey: Int)(
+        operatorContext: StatefulOperatorContext)(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
 
       mapPartitionsWithStateStore(
         stateInfo,
         keySchema,
         valueSchema,
-        numColsPrefixKey,
+        operatorContext,
         sqlContext.sessionState,
         Some(sqlContext.streams.stateStoreCoordinator))(
         storeUpdateFunction)
@@ -53,7 +53,7 @@ package object state {
         stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,
-        numColsPrefixKey: Int,
+        operatorContext: StatefulOperatorContext,
         sessionState: SessionState,
         storeCoordinator: Option[StateStoreCoordinatorRef],
         extraOptions: Map[String, String] = Map.empty)(
@@ -77,7 +77,7 @@ package object state {
         stateInfo.storeVersion,
         keySchema,
         valueSchema,
-        numColsPrefixKey,
+        operatorContext,
         sessionState,
         storeCoordinator,
         extraOptions)
@@ -88,7 +88,7 @@ package object state {
         stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,
-        numColsPrefixKey: Int,
+        operatorContext: StatefulOperatorContext,
         sessionState: SessionState,
         storeCoordinator: Option[StateStoreCoordinatorRef],
         extraOptions: Map[String, String] = Map.empty)(
@@ -112,7 +112,7 @@ package object state {
         stateInfo.storeVersion,
         keySchema,
         valueSchema,
-        numColsPrefixKey,
+        operatorContext,
         sessionState,
         storeCoordinator,
         extraOptions)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 3431823..923d94e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -237,11 +237,10 @@ trait WatermarkSupport extends SparkPlan {
   protected def removeKeysOlderThanWatermark(store: StateStore): Unit = {
     if (watermarkPredicateForKeys.nonEmpty) {
       val numRemovedStateRows = longMetric("numRemovedStateRows")
-      store.iterator().foreach { rowPair =>
-        if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
-          store.remove(rowPair.key)
-          numRemovedStateRows += 1
-        }
+      // The key predicate is the fallback for stores that cannot evict by watermark alone;
+      // draining the iterator performs the eviction.
+      val evicted = store.evictOnWatermark(eventTimeWatermark.get,
+        rowPair => watermarkPredicateForKeys.get.eval(rowPair.key))
+      evicted.foreach { _ =>
+        numRemovedStateRows += 1
       }
     }
   }
@@ -251,11 +250,10 @@ trait WatermarkSupport extends SparkPlan {
       store: StateStore): Unit = {
     if (watermarkPredicateForKeys.nonEmpty) {
       val numRemovedStateRows = longMetric("numRemovedStateRows")
-      storeManager.keys(store).foreach { keyRow =>
-        if (watermarkPredicateForKeys.get.eval(keyRow)) {
-          storeManager.remove(store, keyRow)
-          numRemovedStateRows += 1
-        }
+      storeManager.evictOnWatermark(store,
+        eventTimeWatermark.get, pair => watermarkPredicateForKeys.get.eval(pair.key)
+      ).foreach { _ =>
+        numRemovedStateRows += 1
       }
     }
   }
@@ -307,7 +305,8 @@ case class StateStoreRestoreExec(
       getStateInfo,
       keyExpressions.toStructType,
       stateManager.getStateValueSchema,
-      numColsPrefixKey = 0,
+      // FIXME: set event time column here!
+      StatefulOperatorContext(),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
         val hasInput = iter.hasNext
@@ -365,11 +364,26 @@ case class StateStoreSaveExec(
     assert(outputMode.nonEmpty,
       "Incorrect planning in IncrementalExecution, outputMode has not been set")
 
+    val eventTimeIdx = keyExpressions.indexWhere(_.metadata.contains(EventTimeWatermark.delayKey))
+    val eventTimeColIdx = if (eventTimeIdx >= 0) {
+      keyExpressions.toStructType(eventTimeIdx).dataType match {
+        // FIXME: for now, we only consider window operation here, as we do the same in
+        //   WatermarkSupport.watermarkExpression
+        case StructType(_) => Array[Int](eventTimeIdx, 1)
+        case TimestampType => Array[Int](eventTimeIdx)
+        case _ => throw new IllegalStateException(
+          "The type of event time column should be timestamp")
+      }
+    } else {
+      Array.empty[Int]
+    }
+
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
       keyExpressions.toStructType,
       stateManager.getStateValueSchema,
-      numColsPrefixKey = 0,
+      // Pass the event time column so the store can index keys by event time for eviction.
+      StatefulOperatorContext(eventTimeColIdx = eventTimeColIdx),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { (store, iter) =>
         val numOutputRows = longMetric("numOutputRows")
@@ -414,18 +428,16 @@ case class StateStoreSaveExec(
             }
 
             val removalStartTimeNs = System.nanoTime
-            val rangeIter = stateManager.iterator(store)
+            val evictedIter = stateManager.evictOnWatermark(store,
+              eventTimeWatermark.get, pair => watermarkPredicateForKeys.get.eval(pair.key))
 
             new NextIterator[InternalRow] {
               override protected def getNext(): InternalRow = {
                 var removedValueRow: InternalRow = null
-                while(rangeIter.hasNext && removedValueRow == null) {
-                  val rowPair = rangeIter.next()
-                  if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
-                    stateManager.remove(store, rowPair.key)
-                    numRemovedStateRows += 1
-                    removedValueRow = rowPair.value
-                  }
+                while(evictedIter.hasNext && removedValueRow == null) {
+                  val rowPair = evictedIter.next()
+                  numRemovedStateRows += 1
+                  removedValueRow = rowPair.value
                 }
                 if (removedValueRow == null) {
                   finished = true
@@ -541,7 +553,8 @@ case class SessionWindowStateStoreRestoreExec(
       getStateInfo,
       stateManager.getStateKeySchema,
       stateManager.getStateValueSchema,
-      numColsPrefixKey = stateManager.getNumColsForPrefixKey,
+      // FIXME: set event time column here!
+      StatefulOperatorContext(stateManager.getNumColsForPrefixKey),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
 
@@ -618,7 +631,8 @@ case class SessionWindowStateStoreSaveExec(
       getStateInfo,
       stateManager.getStateKeySchema,
       stateManager.getStateValueSchema,
-      numColsPrefixKey = stateManager.getNumColsForPrefixKey,
+      // FIXME: set event time column!
+      StatefulOperatorContext(numColsPrefixKey = stateManager.getNumColsForPrefixKey),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
 
@@ -652,6 +666,7 @@ case class SessionWindowStateStoreSaveExec(
 
           val removalStartTimeNs = System.nanoTime
           new NextIterator[InternalRow] {
+            // FIXME: can we optimize this case as well?
             private val removedIter = stateManager.removeByValueCondition(
               store, watermarkPredicateForData.get.eval)
 
@@ -751,7 +766,8 @@ case class StreamingDeduplicateExec(
       getStateInfo,
       keyExpressions.toStructType,
       child.output.toStructType,
-      numColsPrefixKey = 0,
+      // FIXME: set event time column!
+      StatefulOperatorContext(),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator),
       // We won't check value row in state store since the value StreamingDeduplicateExec.EMPTY_ROW
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
index 8bba9b8..ba15f50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, SortOrder, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{LimitExec, SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.streaming.state.StateStoreOps
+import org.apache.spark.sql.execution.streaming.state.{StatefulOperatorContext, StateStoreOps}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType}
 import org.apache.spark.util.{CompletionIterator, NextIterator}
@@ -52,7 +52,8 @@ case class StreamingGlobalLimitExec(
         getStateInfo,
         keySchema,
         valueSchema,
-        numColsPrefixKey = 0,
+        // FIXME: set event time column!
+        StatefulOperatorContext(),
         session.sessionState,
         Some(session.streams.stateStoreCoordinator)) { (store, iter) =>
       val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
new file mode 100644
index 0000000..9a83f5c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
@@ -0,0 +1,633 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import java.{util => jutil}
+
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.rocksdb.RocksDBException
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StatefulOperatorContext, StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType, TimestampType}
+import org.apache.spark.util.Utils
+
+/**
+ * Synthetic benchmark for State Store operations.
+ * To run this benchmark:
+ * {{{
+ *   1. without sbt:
+ *      bin/spark-submit --class <this class>
+ *        --jars <spark core test jar>,<spark catalyst test jar> <sql core test jar>
+ *   2. build/sbt "sql/test:runMain <this class>"
+ *   3. generate result:
+ *      SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ *      Results will be written to "benchmarks/StateStoreBenchmark-results.txt".
+ * }}}
+ */
+object StateStoreBenchmark extends SqlBasedBenchmark {
+
+  // Row counts used by the scan / simulation benchmarks kept (disabled) in runBenchmarkSuite.
+  // Larger alternative: Seq(10000, 100000, 1000000).
+  private val numOfRows: Seq[Int] = Seq(10000, 50000, 100000)
+
+  // Update rates in percent, relative to the number of rows in the previous batch.
+  // Full sweep: Seq(200, 100, 50, 25, 10, 5, 1, 0), i.e. 200% down to no update.
+  private val updateRates: Seq[Int] = Seq(25, 10, 5)
+
+  // Eviction rates in percent (100% down to no eviction), relative to the number of rows in
+  // the previous batch.
+  private val evictRates: Seq[Int] = Seq(100, 50, 25, 10, 5, 1, 0)
+
+  // State key is (key1: int, key2: timestamp); key2 is the column registered as event time by
+  // newRocksDBStateProviderWithEventTimeIdx. State value is (value: int).
+  private val keySchema = StructType(
+    Seq("key1" -> IntegerType, "key2" -> TimestampType).map { case (name, dataType) =>
+      StructField(name, dataType, nullable = true)
+    })
+  private val valueSchema = StructType(
+    Seq(StructField("value", IntegerType, nullable = true)))
+
+  // Shared projections turning generic rows into UnsafeRows of the schemas above.
+  private val keyProjection = UnsafeProjection.create(keySchema)
+  private val valueProjection = UnsafeProjection.create(valueSchema)
+
+  /**
+   * Benchmarks evicting rows by event time from a RocksDB state store created with an event
+   * time index on key column 1. It compares a full scan with per-row `remove` against the
+   * `evictOnWatermark` API.
+   *
+   * For every (row count, distinct timestamp count) pair, the test data is written and
+   * committed once. Each timer iteration then reopens that committed version, evicts, and
+   * aborts, so every iteration starts from the same state. Both timer cases use the same
+   * indexed provider and differ only in the eviction path.
+   */
+  private def runEvictBenchmark(): Unit = {
+    runBenchmark("evict rows") {
+      val numOfRows = Seq(10000) // Seq(1000, 10000, 100000)
+      val numOfTimestamps = Seq(10, 100, 1000)
+      val numOfEvictionRates = Seq(50, 25, 10, 5, 1, 0) // Seq(100, 75, 50, 25, 1, 0)
+
+      numOfRows.foreach { numOfRow =>
+        numOfTimestamps.foreach { numOfTimestamp =>
+          val timestampsInMicros = (0L until numOfTimestamp).map(ts => ts * 1000L).toList
+          // The generator cycles through the timestamps evenly (it asserts divisibility), so
+          // each timestamp is shared by exactly this many rows.
+          val rowsPerTimestamp = numOfRow / numOfTimestamp
+
+          val testData = constructRandomizedTestData(numOfRow, timestampsInMicros, 0)
+
+          val rocksDBProvider = newRocksDBStateProviderWithEventTimeIdx()
+          try {
+            val initialStore = rocksDBProvider.getStore(0)
+            updateRows(initialStore, testData)
+
+            val committedVersion = try {
+              initialStore.commit()
+            } catch {
+              case exc: RocksDBException =>
+                // getStatus may be null; an NPE here would hide the original exception.
+                val status = Option(exc.getStatus)
+                  .map(s => s"${s.getState} / ${s.getCodeString}")
+                  .getOrElse("unknown")
+                // scalastyle:off println
+                System.out.println(s"Exception in RocksDB happen! ${exc.getMessage} / " +
+                  s"status: $status")
+                // scalastyle:on println
+                exc.printStackTrace()
+                throw exc
+            }
+
+            numOfEvictionRates.foreach { numOfEvictionRate =>
+              // Only whole timestamps are evicted, so derive the row count from the number of
+              // timestamps actually evicted. The integer division may round that count down
+              // (e.g. 1% of 10 timestamps is 0 timestamps, hence 0 rows).
+              val numOfTimestampsToEvict = numOfTimestamp * numOfEvictionRate / 100
+              val numOfRowsToEvict = numOfTimestampsToEvict * rowsPerTimestamp
+              // -1 when nothing should be evicted: every generated timestamp is >= 0.
+              val maxTimestampToEvictInMillis = timestampsInMicros
+                .take(numOfTimestampsToEvict)
+                .lastOption.map(_ / 1000).getOrElse(-1L)
+
+              val benchmark = new Benchmark(s"evicting $numOfRowsToEvict rows " +
+                s"(max timestamp to evict in millis: $maxTimestampToEvictInMillis) " +
+                s"from $numOfRow rows with $numOfTimestamp timestamps " +
+                s"($rowsPerTimestamp rows for the same timestamp)",
+                numOfRow, minNumIters = 1000, output = output)
+
+              benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+                val store = rocksDBProvider.getStore(committedVersion)
+
+                timer.startTiming()
+                evictAsFullScanAndRemove(store, maxTimestampToEvictInMillis)
+                timer.stopTiming()
+
+                store.abort()
+              }
+
+              benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer =>
+                val store = rocksDBProvider.getStore(committedVersion)
+
+                timer.startTiming()
+                evictAsNewEvictApi(store, maxTimestampToEvictInMillis)
+                timer.stopTiming()
+
+                store.abort()
+              }
+
+              benchmark.run()
+            }
+          } finally {
+            rocksDBProvider.close()
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Benchmarks `put` on RocksDB state stores with and without an event time index on key
+   * column 1, for several distinct-timestamp counts. Every timer iteration writes the same
+   * rows into a fresh version-0 store and aborts it, so nothing is committed.
+   */
+  private def runPutBenchmark(): Unit = {
+    runBenchmark("put rows") {
+      val numOfRows = Seq(10000) // Seq(1000, 10000, 100000)
+      val numOfTimestamps = Seq(100, 1000, 10000) // Seq(1, 10, 100, 1000, 10000)
+      numOfRows.foreach { numOfRow =>
+        numOfTimestamps.foreach { numOfTimestamp =>
+          val timestamps = (0L until numOfTimestamp).map(ts => ts * 1000L).toList
+
+          val testData = constructRandomizedTestData(numOfRow, timestamps, 0)
+
+          // Nested try/finally: each provider is closed even if creating the other one, or
+          // running the benchmark, throws.
+          val rocksDBProvider = newRocksDBStateProvider()
+          try {
+            val rocksDBWithIdxProvider = newRocksDBStateProviderWithEventTimeIdx()
+            try {
+              val benchmark = new Benchmark(s"putting $numOfRow rows, with $numOfTimestamp " +
+                s"timestamps (${numOfRow / numOfTimestamp} rows for the same timestamp)",
+                numOfRow, minNumIters = 1000, output = output)
+
+              benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+                val rocksDBStore = rocksDBProvider.getStore(0)
+
+                timer.startTiming()
+                updateRows(rocksDBStore, testData)
+                timer.stopTiming()
+
+                rocksDBStore.abort()
+              }
+
+              benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer =>
+                val rocksDBWithIdxStore = rocksDBWithIdxProvider.getStore(0)
+
+                timer.startTiming()
+                updateRows(rocksDBWithIdxStore, testData)
+                timer.stopTiming()
+
+                rocksDBWithIdxStore.abort()
+              }
+
+              benchmark.run()
+            } finally {
+              rocksDBWithIdxProvider.close()
+            }
+          } finally {
+            rocksDBProvider.close()
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Entry point invoked by the benchmark harness. Currently it runs only the eviction
+   * benchmark. `runPutBenchmark()` is defined, but its call is commented out. `mainArgs` is
+   * not used.
+   */
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    // runPutBenchmark()
+    runEvictBenchmark()
+
+    // NOTE(review): the block comment below is disabled code that predates the current helper
+    // signatures. `constructRandomizedTestData(numOfRows.max)` and the `minIdx = ...` calls
+    // omit the now-required `timestamps: List[Long]` argument, so it would not compile if
+    // uncommented. Its eviction helpers now take an event-time bound (`maxTimestampToEvict`),
+    // but these call sites pass a row index (`maxIdxToEvict`). Update both before re-enabling.
+    /*
+    val testData = constructRandomizedTestData(numOfRows.max)
+
+    skip("scanning and comparing") {
+      numOfRows.foreach { numOfRow =>
+        val curData = testData.take(numOfRow)
+
+        val inMemoryProvider = newHDFSBackedStateStoreProvider()
+        val inMemoryStore = inMemoryProvider.getStore(0)
+
+        val rocksDBProvider = newRocksDBStateProvider()
+        val rocksDBStore = rocksDBProvider.getStore(0)
+
+        updateRows(inMemoryStore, curData)
+        updateRows(rocksDBStore, curData)
+
+        val newVersionForInMemory = inMemoryStore.commit()
+        val newVersionForRocksDB = rocksDBStore.commit()
+
+        val benchmark = new Benchmark(s"scanning and comparing $numOfRow rows",
+          numOfRow, minNumIters = 1000, output = output)
+
+        benchmark.addTimerCase("HDFSBackedStateStoreProvider") { timer =>
+          val inMemoryStore2 = inMemoryProvider.getStore(newVersionForInMemory)
+
+          timer.startTiming()
+          // NOTE: the latency would be quite similar regardless of the rate of eviction
+          // as we don't remove the actual row, so I simply picked 10 %
+          fullScanAndCompareTimestamp(inMemoryStore2, (numOfRow * 0.1).toInt)
+          timer.stopTiming()
+
+          inMemoryStore2.abort()
+        }
+
+        benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+          val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB)
+
+          timer.startTiming()
+          // NOTE: the latency would be quite similar regardless of the rate of eviction
+          // as we don't remove the actual row, so I simply picked 10 %
+          fullScanAndCompareTimestamp(rocksDBStore2, (numOfRow * 0.1).toInt)
+          timer.stopTiming()
+
+          rocksDBStore2.abort()
+        }
+
+        benchmark.run()
+
+        inMemoryProvider.close()
+        rocksDBProvider.close()
+      }
+    }
+
+    // runBenchmark("simulate full operations on eviction") {
+    skip("simulate full operations on eviction") {
+      numOfRows.foreach { numOfRow =>
+        val curData = testData.take(numOfRow)
+
+        val inMemoryProvider = newHDFSBackedStateStoreProvider()
+        val inMemoryStore = inMemoryProvider.getStore(0)
+
+        val rocksDBProvider = newRocksDBStateProvider()
+        val rocksDBStore = rocksDBProvider.getStore(0)
+
+        val indexForInMemoryStore = new jutil.concurrent.ConcurrentSkipListMap[
+          Int, jutil.List[UnsafeRow]]()
+        val indexForRocksDBStore = new jutil.concurrent.ConcurrentSkipListMap[
+          Int, jutil.List[UnsafeRow]]()
+
+        updateRowsWithSortedMapIndex(inMemoryStore, indexForInMemoryStore, curData)
+        updateRowsWithSortedMapIndex(rocksDBStore, indexForRocksDBStore, curData)
+
+        assert(indexForInMemoryStore.size() == numOfRow)
+        assert(indexForRocksDBStore.size() == numOfRow)
+
+        val newVersionForInMemory = inMemoryStore.commit()
+        val newVersionForRocksDB = rocksDBStore.commit()
+
+        val rowsToUpdate = constructRandomizedTestData(numOfRow / 100 * updateRates.max,
+          minIdx = numOfRow + 1)
+
+        updateRates.foreach { updateRate =>
+          val numRowsUpdate = numOfRow / 100 * updateRate
+          val curRowsToUpdate = rowsToUpdate.take(numRowsUpdate)
+
+          evictRates.foreach { evictRate =>
+            val maxIdxToEvict = numOfRow / 100 * evictRate
+
+            val benchmark = new Benchmark(s"simulating evict on $numOfRow rows, update " +
+              s"$numRowsUpdate rows ($updateRate %), evict $maxIdxToEvict rows ($evictRate %)",
+              numOfRow, minNumIters = 100, output = output)
+
+            benchmark.addTimerCase("HDFSBackedStateStoreProvider") { timer =>
+              val inMemoryStore2 = inMemoryProvider.getStore(newVersionForInMemory)
+
+              timer.startTiming()
+              updateRows(inMemoryStore2, curRowsToUpdate)
+              evictAsFullScanAndRemove(inMemoryStore2, maxIdxToEvict)
+              timer.stopTiming()
+
+              inMemoryStore2.abort()
+            }
+
+            benchmark.addTimerCase("HDFSBackedStateStoreProvider - sorted map index") { timer =>
+
+              val inMemoryStore2 = inMemoryProvider.getStore(newVersionForInMemory)
+
+              val curIndex = new jutil.concurrent.ConcurrentSkipListMap[Int,
+                jutil.List[UnsafeRow]]()
+              curIndex.putAll(indexForInMemoryStore)
+
+              assert(curIndex.size() == numOfRow)
+
+              timer.startTiming()
+              updateRowsWithSortedMapIndex(inMemoryStore2, curIndex, curRowsToUpdate)
+
+              assert(curIndex.size() == numOfRow + curRowsToUpdate.size)
+
+              evictAsScanSortedMapIndexAndRemove(inMemoryStore2, curIndex, maxIdxToEvict)
+              timer.stopTiming()
+
+              assert(curIndex.size() == numOfRow + curRowsToUpdate.size - maxIdxToEvict)
+
+              curIndex.clear()
+
+              inMemoryStore2.abort()
+            }
+
+            benchmark.run()
+
+            val benchmark2 = new Benchmark(s"simulating evict on $numOfRow rows, update " +
+              s"$numRowsUpdate rows ($updateRate %), evict $maxIdxToEvict rows ($evictRate %)",
+              numOfRow, minNumIters = 100, output = output)
+
+            benchmark2.addTimerCase("RocksDBStateStoreProvider") { timer =>
+              val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB)
+
+              timer.startTiming()
+              updateRows(rocksDBStore2, curRowsToUpdate)
+              evictAsFullScanAndRemove(rocksDBStore2, maxIdxToEvict)
+              timer.stopTiming()
+
+              rocksDBStore2.abort()
+            }
+
+            benchmark2.addTimerCase("RocksDBStateStoreProvider - sorted map index") { timer =>
+
+              val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB)
+
+              val curIndex = new jutil.concurrent.ConcurrentSkipListMap[Int,
+                jutil.List[UnsafeRow]]()
+              curIndex.putAll(indexForRocksDBStore)
+
+              assert(curIndex.size() == numOfRow)
+
+              timer.startTiming()
+              updateRowsWithSortedMapIndex(rocksDBStore2, curIndex, curRowsToUpdate)
+
+              assert(curIndex.size() == numOfRow + curRowsToUpdate.size)
+
+              evictAsScanSortedMapIndexAndRemove(rocksDBStore2, curIndex, maxIdxToEvict)
+              timer.stopTiming()
+
+              assert(curIndex.size() == numOfRow + curRowsToUpdate.size - maxIdxToEvict)
+
+              curIndex.clear()
+
+              rocksDBStore2.abort()
+            }
+
+            benchmark2.run()
+          }
+        }
+
+        inMemoryProvider.close()
+        rocksDBProvider.close()
+      }
+    }
+
+    // runBenchmark("simulate full operations on eviction") {
+    skip("simulate full operations on eviction - scannable index") {
+      numOfRows.foreach { numOfRow =>
+        val curData = testData.take(numOfRow)
+
+        val rocksDBProvider = newRocksDBStateProvider()
+        val rocksDBStore = rocksDBProvider.getStore(0)
+
+        val rocksDBWithIdxProvider = newRocksDBStateProviderWithEventTimeIdx()
+        val rocksDBWithIdxStore = rocksDBWithIdxProvider.getStore(0)
+
+        updateRows(rocksDBStore, curData)
+        updateRows(rocksDBWithIdxStore, curData)
+
+        val newVersionForRocksDB = rocksDBStore.commit()
+        val newVersionForRocksDBWithIdx = rocksDBWithIdxStore.commit()
+
+        val rowsToUpdate = constructRandomizedTestData(numOfRow / 100 * updateRates.max,
+          minIdx = numOfRow + 1)
+
+        updateRates.foreach { updateRate =>
+          val numRowsUpdate = numOfRow / 100 * updateRate
+          val curRowsToUpdate = rowsToUpdate.take(numRowsUpdate)
+
+          evictRates.foreach { evictRate =>
+            val maxIdxToEvict = numOfRow / 100 * evictRate
+
+            val benchmark = new Benchmark(s"simulating evict on $numOfRow rows, update " +
+              s"$numRowsUpdate rows ($updateRate %), evict $maxIdxToEvict rows ($evictRate %)",
+              numOfRow, minNumIters = 100, output = output)
+
+            benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+              val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB)
+
+              timer.startTiming()
+              updateRows(rocksDBStore2, curRowsToUpdate)
+              evictAsFullScanAndRemove(rocksDBStore2, maxIdxToEvict)
+              // evictAsNewEvictApi(rocksDBStore2, maxIdxToEvict)
+              timer.stopTiming()
+
+              rocksDBStore2.abort()
+            }
+
+            benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer =>
+              val rocksDBWithIdxStore2 = rocksDBWithIdxProvider.getStore(
+                newVersionForRocksDBWithIdx)
+
+              timer.startTiming()
+              updateRows(rocksDBWithIdxStore2, curRowsToUpdate)
+              evictAsNewEvictApi(rocksDBWithIdxStore2, maxIdxToEvict)
+              timer.stopTiming()
+
+              rocksDBWithIdxStore2.abort()
+            }
+
+            benchmark.run()
+          }
+        }
+
+        rocksDBProvider.close()
+        rocksDBWithIdxProvider.close()
+      }
+    }
+
+    runBenchmark("put rows") {
+      numOfRows.foreach { numOfRow =>
+        val curData = testData.take(numOfRow)
+
+        val rocksDBProvider = newRocksDBStateProvider()
+        val rocksDBWithIdxProvider = newRocksDBStateProviderWithEventTimeIdx()
+
+        val benchmark = new Benchmark(s"putting $numOfRow rows",
+          numOfRow, minNumIters = 1000, output = output)
+
+        benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+          val rocksDBStore = rocksDBProvider.getStore(0)
+
+          timer.startTiming()
+          updateRows(rocksDBStore, curData)
+          timer.stopTiming()
+
+          rocksDBStore.abort()
+        }
+
+        benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer =>
+          val rocksDBWithIdxStore = rocksDBWithIdxProvider.getStore(0)
+
+          timer.startTiming()
+          updateRows(rocksDBWithIdxStore, curData)
+          timer.stopTiming()
+
+          rocksDBWithIdxStore.abort()
+        }
+
+        benchmark.run()
+
+        rocksDBProvider.close()
+        rocksDBWithIdxProvider.close()
+      }
+    }
+     */
+  }
+
+  /**
+   * Records in the benchmark output that `benchmarkName` was skipped, without evaluating
+   * `func`. Used in place of `runBenchmark` to temporarily disable a section while keeping
+   * its code compiled.
+   */
+  final def skip(benchmarkName: String)(func: => Any): Unit = {
+    // Terminate the line so the message does not run into subsequent benchmark output, and
+    // encode explicitly instead of relying on the platform-default charset.
+    output.foreach(_.write(
+      s"$benchmarkName is skipped\n".getBytes(java.nio.charset.StandardCharsets.UTF_8)))
+  }
+
+  /** Writes every (key, value) pair of `rows` into `store`, in sequence order. */
+  private def updateRows(
+      store: StateStore,
+      rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = {
+    for ((key, value) <- rows) {
+      store.put(key, value)
+    }
+  }
+
+  /**
+   * Evicts by brute force. It scans every row of `store` and removes those whose event time is
+   * at or before `maxTimestampToEvict`. The event time is key column 1, a TimestampType value
+   * generated in microseconds.
+   *
+   * @param maxTimestampToEvict inclusive eviction bound in milliseconds, i.e. the same unit as
+   *                            the `watermarkMs` passed to `StateStore.evictOnWatermark`.
+   *                            A negative bound evicts nothing from the generated test data.
+   */
+  private def evictAsFullScanAndRemove(
+      store: StateStore,
+      maxTimestampToEvict: Long): Unit = {
+    // Keys hold microseconds: scale the millisecond bound once rather than per row.
+    val maxTimestampToEvictInMicros = maxTimestampToEvict * 1000L
+    store.iterator().foreach { r =>
+      if (r.key.getLong(1) <= maxTimestampToEvictInMicros) {
+        store.remove(r.key)
+      }
+    }
+  }
+
+  /**
+   * Evicts through `StateStore.evictOnWatermark`, which can use an event time index when the
+   * store has one. The fallback predicate mirrors `evictAsFullScanAndRemove`: remove rows whose
+   * key column 1 (microseconds) is at or before the bound.
+   *
+   * @param maxTimestampToEvict inclusive eviction bound in milliseconds (the watermark).
+   */
+  private def evictAsNewEvictApi(
+      store: StateStore,
+      maxTimestampToEvict: Long): Unit = {
+    val maxTimestampToEvictInMicros = maxTimestampToEvict * 1000L
+    // The returned iterator is fully drained in case the implementation evicts lazily while
+    // it is consumed (as MemoryStateStore's does).
+    store.evictOnWatermark(maxTimestampToEvict, pair => {
+      pair.key.getLong(1) <= maxTimestampToEvictInMicros
+    }).foreach { _ => }
+  }
+
+  /**
+   * Scans every row of `store` and evaluates the eviction condition without removing anything.
+   * This isolates the cost of the scan and the comparison.
+   *
+   * @param maxIdxToEvict exclusive row-index bound. constructTestData encodes row index `i` as
+   *                      an event time of `i * 1000` microseconds, so the bound is scaled the
+   *                      same way before comparing.
+   */
+  private def fullScanAndCompareTimestamp(
+      store: StateStore,
+      maxIdxToEvict: Int): Unit = {
+    val maxTimestampInMicros = maxIdxToEvict * 1000L
+    var i: Long = 0
+    store.iterator().foreach { r =>
+      // Key column 1 is a TimestampType (Long); getInt would read only part of the value.
+      if (r.key.getLong(1) < maxTimestampInMicros) {
+        // simply to avoid the "if statement" to be no-op
+        i += 1
+      }
+    }
+  }
+
+  /**
+   * Puts every (key, value) pair into `store` and also records each key in `index`, grouped by
+   * row index. The index can later be range-scanned for eviction.
+   *
+   * The row index is derived from key column 1, a TimestampType (Long). constructTestData
+   * stores row `i` as `i * 1000` microseconds, so dividing by 1000 recovers `i`.
+   */
+  private def updateRowsWithSortedMapIndex(
+      store: StateStore,
+      index: jutil.SortedMap[Int, jutil.List[UnsafeRow]],
+      rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = {
+    rows.foreach { case (key, value) =>
+      // Read the column with its real width; getInt on a Long column reads only part of it.
+      val idx = (key.getLong(1) / 1000L).toInt
+
+      // Creates the bucket if absent and returns it in one call. This is atomic on a
+      // ConcurrentSkipListMap, unlike the containsKey/get/put sequence.
+      index.computeIfAbsent(idx, (_: Int) => new jutil.ArrayList[UnsafeRow]()).add(key)
+
+      store.put(key, value)
+    }
+  }
+
+  /**
+   * Removes from `store` every key indexed under a row index <= `maxIdxToEvict`. The emptied
+   * buckets are also dropped from `index`.
+   */
+  private def evictAsScanSortedMapIndexAndRemove(
+      store: StateStore,
+      index: jutil.SortedMap[Int, jutil.List[UnsafeRow]],
+      maxIdxToEvict: Int): Unit = {
+    // headMap excludes its bound, hence the +1 so that maxIdxToEvict itself is included.
+    val evictedBuckets = index.headMap(maxIdxToEvict + 1).values().iterator()
+    while (evictedBuckets.hasNext) {
+      val bucket = evictedBuckets.next()
+      val bucketKeys = bucket.iterator()
+      while (bucketKeys.hasNext) {
+        store.remove(bucketKeys.next())
+      }
+      bucket.clear()
+      // Removing through the view's iterator also removes the entry from `index`.
+      evictedBuckets.remove()
+    }
+  }
+
+  // FIXME: should the size of key / value be variables?
+  /**
+   * Builds `numRows` (key, value) pairs with ascending event times. For 1-based row `i`, with
+   * `n = minIdx + i`, the key is (1, n * 1000 microseconds) and the value is (n). Each projected
+   * row is copied, presumably because the shared projections reuse their output buffer.
+   */
+  private def constructTestData(numRows: Int, minIdx: Int = 0): Seq[(UnsafeRow, UnsafeRow)] = {
+    (1 to numRows).map { i =>
+      val rowIdx = minIdx + i
+
+      val key = new GenericInternalRow(2)
+      key.setInt(0, 1)
+      key.setLong(1, rowIdx * 1000L) // microseconds
+
+      val value = new GenericInternalRow(1)
+      value.setInt(0, rowIdx)
+
+      (keyProjection(key).copy(), valueProjection(value).copy())
+    }
+  }
+
+  /**
+   * Builds `numRows` (key, value) pairs whose event times cycle through `timestamps`, so every
+   * timestamp is shared by exactly `numRows / timestamps.length` rows. Key column 0 is random,
+   * so keys are not generated in sorted order, which could otherwise skew RocksDB results.
+   *
+   * @param numRows    number of rows; must be a positive multiple of `timestamps.length`
+   * @param timestamps candidate event times in microseconds; must be non-empty
+   * @param minIdx     offset added to the 1-based row number. It is used for the value column
+   *                   and for the starting position in `timestamps`.
+   */
+  private def constructRandomizedTestData(
+      numRows: Int,
+      timestamps: List[Long],
+      minIdx: Int = 0): Seq[(UnsafeRow, UnsafeRow)] = {
+    // Checked first: with an empty list the divisibility check below would fail with an
+    // ArithmeticException instead of a meaningful message.
+    require(timestamps.nonEmpty, "timestamps must not be empty")
+    assert(numRows >= timestamps.length)
+    assert(numRows % timestamps.length == 0)
+
+    // List.apply is O(n): index an IndexedSeq so generation stays linear in numRows.
+    val timestampsByPos = timestamps.toIndexedSeq
+
+    (1 to numRows).map { idx =>
+      val keyRow = new GenericInternalRow(2)
+      keyRow.setInt(0, Random.nextInt(Int.MaxValue))
+      // floorMod keeps the position in range even for a negative minIdx.
+      val pos = Math.floorMod(minIdx + idx, timestampsByPos.length)
+      keyRow.setLong(1, timestampsByPos(pos)) // microseconds
+      val valueRow = new GenericInternalRow(1)
+      valueRow.setInt(0, minIdx + idx)
+
+      val keyUnsafeRow = keyProjection(keyRow).copy()
+      val valueUnsafeRow = valueProjection(valueRow).copy()
+
+      (keyUnsafeRow, valueUnsafeRow)
+    }
+  }
+
+  /** Creates an HDFS-backed state store provider with a default operator context. */
+  private def newHDFSBackedStateStoreProvider(): StateStoreProvider =
+    initProvider(new HDFSBackedStateStoreProvider(), StatefulOperatorContext())
+
+  /** Creates a RocksDB state store provider with a default operator context. */
+  private def newRocksDBStateProvider(): StateStoreProvider =
+    initProvider(new RocksDBStateStoreProvider(), StatefulOperatorContext())
+
+  /**
+   * Creates a RocksDB state store provider whose operator context marks key column 1 (the
+   * timestamp column) as the event time column.
+   */
+  private def newRocksDBStateProviderWithEventTimeIdx(): StateStoreProvider =
+    initProvider(
+      new RocksDBStateStoreProvider(),
+      StatefulOperatorContext(eventTimeColIdx = Array(1)))
+
+  /**
+   * Initializes `provider` for a new StateStoreId rooted at a fresh temporary directory. It
+   * uses this benchmark's key/value schemas, the given operator context, and zstd as the state
+   * store compression codec. Returns the same provider.
+   */
+  private def initProvider(
+      provider: StateStoreProvider,
+      operatorContext: StatefulOperatorContext): StateStoreProvider = {
+    val storeId = StateStoreId(newDir(), Random.nextInt(), 0)
+    val sqlConf = new SQLConf()
+    sqlConf.setConfString("spark.sql.streaming.stateStore.compression.codec", "zstd")
+    provider.init(
+      storeId, keySchema, valueSchema, operatorContext,
+      new StateStoreConf(sqlConf), new Configuration)
+    provider
+  }
+
+  /** Creates a new temporary directory via `Utils.createTempDir` and returns its path. */
+  private def newDir(): String = {
+    val dir = Utils.createTempDir()
+    // File.toString is defined to return getPath().
+    dir.getPath
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
index 81f1a3f..4f7c040 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
-import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId, StreamingSessionWindowStateManager}
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StatefulOperatorContext, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId, StreamingSessionWindowStateManager}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamTest
 import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
@@ -217,9 +217,12 @@ class MergingSortWithSessionWindowStateIteratorSuite extends StreamTest with Bef
         stateFormatVersion)
 
       val storeProviderId = StateStoreProviderId(stateInfo, 0, StateStoreId.DEFAULT_STORE_NAME)
+
+      // FIXME: event time column?
       val store = StateStore.get(
         storeProviderId, manager.getStateKeySchema, manager.getStateValueSchema,
-        manager.getNumColsForPrefixKey, stateInfo.storeVersion, storeConf, new Configuration)
+        StatefulOperatorContext(numColsPrefixKey = manager.getNumColsForPrefixKey),
+        stateInfo.storeVersion, storeConf, new Configuration)
 
       try {
         f(manager, store)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
index e52ccd0..4b53c0a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
@@ -50,4 +50,18 @@ class MemoryStateStore extends StateStore() {
   override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
     throw new UnsupportedOperationException("Doesn't support prefix scan!")
   }
+
+  /**
+   * Fallback eviction for this in-memory test store. It ignores `watermarkMs` and removes every
+   * row for which `altPred` returns true.
+   *
+   * The result is lazy: rows are tested and removed only as the returned iterator is consumed.
+   * The iterator yields exactly the pairs that were removed.
+   */
+  override def evictOnWatermark(
+      watermarkMs: Long,
+      altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+    iterator().filter { pair =>
+      val shouldEvict = altPred(pair)
+      if (shouldEvict) remove(pair.key)
+      shouldEvict
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
index 2d741d3..7374303 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters
 import org.scalatest.time.{Minute, Span}
 
 import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper}
-import org.apache.spark.sql.functions.count
+import org.apache.spark.sql.functions.{count, timestamp_seconds, window}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
 
@@ -52,6 +52,64 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest {
     }
   }
 
+  test("append mode") {
+    val input = MemoryStream[Int]
+    val rocksDBConfs = Map(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName)
+
+    // Counts events per 5-second tumbling window with a 10-second watermark and emits
+    // (window start in seconds, count).
+    val windowCounts = input.toDF()
+      .withColumn("eventTime", timestamp_seconds($"value"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
+
+    testStream(windowCounts)(
+      StartStream(additionalConfs = rocksDBConfs),
+      AddData(input, 10, 11, 12, 13, 14, 15),
+      CheckNewAnswer(),
+      AddData(input, 25),   // Advance watermark to 15 seconds
+      CheckNewAnswer((10, 5)),
+      // assertNumStateRows(2),
+      // assertNumRowsDroppedByWatermark(0),
+      AddData(input, 10),   // Should not emit anything as data less than watermark
+      CheckNewAnswer()
+      // assertNumStateRows(2),
+      // assertNumRowsDroppedByWatermark(1)
+    )
+  }
+
+  test("update mode") {
+    val input = MemoryStream[Int]
+    val rocksDBConfs = Map(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName)
+
+    // Counts events per 5-second tumbling window with a 10-second watermark. In update mode,
+    // each batch emits the windows whose counts changed in that batch.
+    val windowCounts = input.toDF()
+      .withColumn("eventTime", timestamp_seconds($"value"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
+
+    testStream(windowCounts, OutputMode.Update)(
+      StartStream(additionalConfs = rocksDBConfs),
+      AddData(input, 10, 11, 12, 13, 14, 15),
+      CheckNewAnswer((10, 5), (15, 1)),
+      AddData(input, 25),     // Advance watermark to 15 seconds
+      CheckNewAnswer((25, 1)),
+      // assertNumStateRows(2),
+      // assertNumRowsDroppedByWatermark(0),
+      AddData(input, 10, 25), // Ignore 10 as it's less than watermark
+      CheckNewAnswer((25, 2)),
+      // assertNumStateRows(2),
+      // assertNumRowsDroppedByWatermark(1),
+      AddData(input, 10),     // Should not emit anything as data less than watermark
+      CheckNewAnswer()
+      // assertNumStateRows(2),
+      // assertNumRowsDroppedByWatermark(1)
+    )
+  }
+
   test("SPARK-36236: query progress contains only the expected RocksDB store custom metrics") {
     // fails if any new custom metrics are added to remind the author of API changes
     import testImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index c93d0f0..c06a9ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -87,8 +87,9 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
         queryRunId = UUID.randomUUID, operatorId = 0, storeVersion = 0, numPartitions = 5)
 
       // Create state store in a task and get the RocksDBConf from the instantiated RocksDB instance
+      // FIXME: event time column?
       val rocksDBConfInTask: RocksDBConf = testRDD.mapPartitionsWithStateStore[RocksDBConf](
-        spark.sqlContext, testStateInfo, testSchema, testSchema, 0) {
+        spark.sqlContext, testStateInfo, testSchema, testSchema, StatefulOperatorContext()) {
           (store: StateStore, _: Iterator[String]) =>
             // Use reflection to get RocksDB instance
             val dbInstanceMethod =
@@ -144,7 +145,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       numColsPrefixKey: Int): RocksDBStateStoreProvider = {
     val provider = new RocksDBStateStoreProvider()
     provider.init(
-      storeId, keySchema, valueSchema, numColsPrefixKey = numColsPrefixKey,
+      storeId, keySchema, valueSchema,
+      StatefulOperatorContext(numColsPrefixKey = numColsPrefixKey),
       new StateStoreConf, new Configuration)
     provider
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 6bb8ebe..0e9b19d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -60,13 +60,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString
       val rdd1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
         .mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 0),
-          keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+          keySchema, valueSchema, StatefulOperatorContext())(increment)
       assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1))
 
       // Generate next version of stores
       val rdd2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("c", 0)))
         .mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 1),
-          keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+          keySchema, valueSchema, StatefulOperatorContext())(increment)
       assert(rdd2.collect().toSet === Set(("a", 0) -> 3, ("b", 0) -> 1, ("c", 0) -> 1))
 
       // Make sure the previous RDD still has the same data.
@@ -84,7 +84,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       implicit val sqlContext = spark.sqlContext
       makeRDD(spark.sparkContext, Seq(("a", 0))).mapPartitionsWithStateStore(
         sqlContext, operatorStateInfo(path, version = storeVersion),
-        keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+        keySchema, valueSchema, StatefulOperatorContext())(increment)
     }
 
     // Generate RDDs and state store data
@@ -134,19 +134,19 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
 
       val rddOfGets1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0)))
         .mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 0),
-          keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfGets)
+          keySchema, valueSchema, StatefulOperatorContext())(iteratorOfGets)
       assert(rddOfGets1.collect().toSet ===
         Set(("a", 0) -> None, ("b", 0) -> None, ("c", 0) -> None))
 
       val rddOfPuts = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
         .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 0),
-          keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfPuts)
+          keySchema, valueSchema, StatefulOperatorContext())(iteratorOfPuts)
       assert(rddOfPuts.collect().toSet ===
         Set(("a", 0) -> 1, ("a", 0) -> 2, ("b", 0) -> 1))
 
       val rddOfGets2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0)))
         .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 1),
-          keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfGets)
+          keySchema, valueSchema, StatefulOperatorContext())(iteratorOfGets)
       assert(rddOfGets2.collect().toSet ===
         Set(("a", 0) -> Some(2), ("b", 0) -> Some(1), ("c", 0) -> None))
     }
@@ -172,7 +172,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
 
         val rdd = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
           .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, queryRunId = queryRunId),
-          keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+          keySchema, valueSchema, StatefulOperatorContext())(increment)
         require(rdd.partitions.length === 2)
 
         assert(
@@ -200,13 +200,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
         val opId = 0
         val rdd1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
           .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 0),
-            keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+            keySchema, valueSchema, StatefulOperatorContext())(increment)
         assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1))
 
         // Generate next version of stores
         val rdd2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("c", 0)))
           .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 1),
-            keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+            keySchema, valueSchema, StatefulOperatorContext())(increment)
         assert(rdd2.collect().toSet === Set(("a", 0) -> 3, ("b", 0) -> 1, ("c", 0) -> 1))
 
         // Make sure the previous RDD still has the same data.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 601b62b..d89c6fed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -270,8 +270,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
     def generateStoreVersions(): Unit = {
       for (i <- 1 to 20) {
-        val store = StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0,
-          latestStoreVersion, storeConf, hadoopConf)
+        val store = StateStore.get(storeProviderId1, keySchema, valueSchema,
+          StatefulOperatorContext(), latestStoreVersion, storeConf, hadoopConf)
         put(store, "a", 0, i)
         store.commit()
         latestStoreVersion += 1
@@ -324,7 +324,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
           }
 
           // Reload the store and verify
-          StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0,
+          StateStore.get(storeProviderId1, keySchema, valueSchema, StatefulOperatorContext(),
             latestStoreVersion, storeConf, hadoopConf)
           assert(StateStore.isLoaded(storeProviderId1))
 
@@ -336,7 +336,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
           }
 
           // Reload the store and verify
-          StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0,
+          StateStore.get(storeProviderId1, keySchema, valueSchema, StatefulOperatorContext(),
             latestStoreVersion, storeConf, hadoopConf)
           assert(StateStore.isLoaded(storeProviderId1))
 
@@ -344,7 +344,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
           // then this executor should unload inactive instances immediately.
           coordinatorRef
             .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
-          StateStore.get(storeProviderId2, keySchema, valueSchema, numColsPrefixKey = 0,
+          StateStore.get(storeProviderId2, keySchema, valueSchema, StatefulOperatorContext(),
             0, storeConf, hadoopConf)
           assert(!StateStore.isLoaded(storeProviderId1))
           assert(StateStore.isLoaded(storeProviderId2))
@@ -453,7 +453,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     // Getting the store should not create temp file
     val store0 = shouldNotCreateTempFile {
       StateStore.get(
-        storeId, keySchema, valueSchema, numColsPrefixKey = 0,
+        storeId, keySchema, valueSchema, StatefulOperatorContext(),
         version = 0, storeConf, hadoopConf)
     }
 
@@ -470,7 +470,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     // Remove should create a temp file
     val store1 = shouldNotCreateTempFile {
       StateStore.get(
-        storeId, keySchema, valueSchema, numColsPrefixKey = 0,
+        storeId, keySchema, valueSchema, StatefulOperatorContext(),
         version = 1, storeConf, hadoopConf)
     }
     remove(store1, _._1 == "a")
@@ -485,7 +485,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     // Commit without any updates should create a delta file
     val store2 = shouldNotCreateTempFile {
       StateStore.get(
-        storeId, keySchema, valueSchema, numColsPrefixKey = 0,
+        storeId, keySchema, valueSchema, StatefulOperatorContext(),
         version = 2, storeConf, hadoopConf)
     }
     store2.commit()
@@ -720,11 +720,12 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
       hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = {
     val sqlConf = getDefaultSQLConf(minDeltasForSnapshot, numOfVersToRetainInMemory)
     val provider = new HDFSBackedStateStoreProvider()
+    // FIXME: should the StatefulOperatorContext below also set an event time column?
     provider.init(
       StateStoreId(dir, opId, partition),
       keySchema,
       valueSchema,
-      numColsPrefixKey = numColsPrefixKey,
+      StatefulOperatorContext(numColsPrefixKey = numColsPrefixKey),
       new StateStoreConf(sqlConf),
       hadoopConf)
     provider
@@ -1027,31 +1028,31 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       // Verify that trying to get incorrect versions throw errors
       intercept[IllegalArgumentException] {
         StateStore.get(
-          storeId, keySchema, valueSchema, 0, -1, storeConf, hadoopConf)
+          storeId, keySchema, valueSchema, StatefulOperatorContext(), -1, storeConf, hadoopConf)
       }
       assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store
 
       intercept[IllegalStateException] {
         StateStore.get(
-          storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf)
+          storeId, keySchema, valueSchema, StatefulOperatorContext(), 1, storeConf, hadoopConf)
       }
 
       // Increase version of the store and try to get again
       val store0 = StateStore.get(
-        storeId, keySchema, valueSchema, 0, 0, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, StatefulOperatorContext(), 0, storeConf, hadoopConf)
       assert(store0.version === 0)
       put(store0, "a", 0, 1)
       store0.commit()
 
       val store1 = StateStore.get(
-        storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, StatefulOperatorContext(), 1, storeConf, hadoopConf)
       assert(StateStore.isLoaded(storeId))
       assert(store1.version === 1)
       assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1))
 
       // Verify that you can also load older version
       val store0reloaded = StateStore.get(
-        storeId, keySchema, valueSchema, 0, 0, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, StatefulOperatorContext(), 0, storeConf, hadoopConf)
       assert(store0reloaded.version === 0)
       assert(rowPairsToDataSet(store0reloaded.iterator()) === Set.empty)
 
@@ -1060,7 +1061,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       assert(!StateStore.isLoaded(storeId))
 
       val store1reloaded = StateStore.get(
-        storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, StatefulOperatorContext(), 1, storeConf, hadoopConf)
       assert(StateStore.isLoaded(storeId))
       assert(store1reloaded.version === 1)
       put(store1reloaded, "a", 0, 2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala
index 096c3bb..dcdbac9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala
@@ -181,9 +181,11 @@ class StreamingSessionWindowStateManagerSuite extends StreamTest with BeforeAndA
         stateFormatVersion)
 
       val storeProviderId = StateStoreProviderId(stateInfo, 0, StateStoreId.DEFAULT_STORE_NAME)
+      // FIXME: should the StatefulOperatorContext below also set an event time column?
       val store = StateStore.get(
         storeProviderId, manager.getStateKeySchema, manager.getStateValueSchema,
-        manager.getNumColsForPrefixKey, stateInfo.storeVersion, storeConf, new Configuration)
+        StatefulOperatorContext(numColsPrefixKey = manager.getNumColsForPrefixKey),
+        stateInfo.storeVersion, storeConf, new Configuration)
 
       try {
         f(manager, store)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index e89197b..d376942 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.{ContinuousMemoryStream, MemorySink}
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
+import org.apache.spark.sql.execution.streaming.state.{StatefulOperatorContext, StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.StreamSourceProvider
@@ -1418,7 +1418,7 @@ class TestStateStoreProvider extends StateStoreProvider {
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       storeConfs: StateStoreConf,
       hadoopConf: Configuration): Unit = {
     throw new Exception("Successfully instantiated")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 77334ad..fbf3aae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.MemorySink
-import org.apache.spark.sql.execution.streaming.state.{StateSchemaNotCompatible, StateStore, StreamingAggregationStateManager}
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateSchemaNotCompatible, StateStore, StreamingAggregationStateManager}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode._
@@ -53,29 +53,47 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
   import testImplicits._
 
+  /** Runs `func` with `confPairs` plus the given state format version and provider class. */
   def executeFuncWithStateVersionSQLConf(
+      providerCls: String,
       stateVersion: Int,
       confPairs: Seq[(String, String)],
       func: => Any): Unit = {
     withSQLConf(confPairs ++
-      Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) {
+      Seq(
+        SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString,
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerCls.stripSuffix("$")): _*) {
       func
     }
   }
 
+  // State store providers that every "with all state versions" test is run against. Shared by
+  // both helpers below so that they always cover the same providers.
+  private def stateStoreProviders: Seq[String] = Seq(
+    classOf[HDFSBackedStateStoreProvider].getCanonicalName,
+    classOf[RocksDBStateStoreProvider].getCanonicalName)
+
+  /** Registers test `name` once per supported state format version and state store provider. */
   def testWithAllStateVersions(name: String, confPairs: (String, String)*)
                               (func: => Any): Unit = {
-    for (version <- StreamingAggregationStateManager.supportedVersions) {
-      test(s"$name - state format version $version") {
-        executeFuncWithStateVersionSQLConf(version, confPairs, func)
+    for {
+      version <- StreamingAggregationStateManager.supportedVersions
+      provider <- stateStoreProviders
+    } {
+      test(s"$name - state format version $version / provider: $provider") {
+        executeFuncWithStateVersionSQLConf(provider, version, confPairs, func)
       }
     }
   }
 
+  /** Same as [[testWithAllStateVersions]], but registers the tests with `testQuietly`. */
   def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*)
                                      (func: => Any): Unit = {
-    for (version <- StreamingAggregationStateManager.supportedVersions) {
-      testQuietly(s"$name - state format version $version") {
-        executeFuncWithStateVersionSQLConf(version, confPairs, func)
+    for {
+      version <- StreamingAggregationStateManager.supportedVersions
+      provider <- stateStoreProviders
+    } {
+      testQuietly(s"$name - state format version $version / provider: $provider") {
+        executeFuncWithStateVersionSQLConf(provider, version, confPairs, func)
       }
     }
   }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org