You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ka...@apache.org on 2021/11/05 07:42:57 UTC

[spark] branch WIP-optimize-eviction-in-rocksdb-state-store created (now 88187c6)

This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a change to branch WIP-optimize-eviction-in-rocksdb-state-store
in repository https://gitbox.apache.org/repos/asf/spark.git.


      at 88187c6  WIP still need to add e2e test and address FIXME/TODOs

This branch includes the following new commits:

     new 21d4f96  WIP: benchmark test code done
     new 88187c6  WIP still need to add e2e test and address FIXME/TODOs

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[spark] 02/02: WIP still need to add e2e test and address FIXME/TODOs

Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch WIP-optimize-eviction-in-rocksdb-state-store
in repository https://gitbox.apache.org/repos/asf/spark.git

commit 88187c64d70a5566bb8f07ce4b133a33e63ce5fc
Author: Jungtaek Lim <ka...@gmail.com>
AuthorDate: Fri Nov 5 16:41:24 2021 +0900

    WIP still need to add e2e test and address FIXME/TODOs
---
 .../sql/execution/streaming/state/RocksDB.scala    | 172 ++++-------------
 .../streaming/state/RocksDBFileManager.scala       |  35 +++-
 .../streaming/state/RocksDBStateEncoder.scala      |  96 +---------
 .../state/RocksDBStateStoreProvider.scala          | 213 +++++++++++++--------
 .../execution/benchmark/StateStoreBenchmark.scala  |  25 ++-
 .../execution/streaming/state/RocksDBSuite.scala   |   8 +-
 6 files changed, 227 insertions(+), 322 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index eed7827..105a446 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.execution.streaming.state
 
 import java.io.File
-import java.util
 import java.util.Locale
 import javax.annotation.concurrent.GuardedBy
 
@@ -51,12 +50,9 @@ import org.apache.spark.util.{NextIterator, Utils}
  * @param hadoopConf   Hadoop configuration for talking to the remote file system
  * @param loggingId    Id that will be prepended in logs for isolating concurrent RocksDBs
  */
-// FIXME: optionally receiving column families
 class RocksDB(
     dfsRootDir: String,
     val conf: RocksDBConf,
-    // TODO: change "default" to constant
-    columnFamilies: Seq[String] = Seq("default"),
     localRootDir: File = Utils.createTempDir(),
     hadoopConf: Configuration = new Configuration,
     loggingId: String = "") extends Logging {
@@ -69,10 +65,16 @@ class RocksDB(
   private val flushOptions = new FlushOptions().setWaitForFlush(true)  // wait for flush to complete
   private val writeBatch = new WriteBatchWithIndex(true)  // overwrite multiple updates to a key
 
-  private val dbOptions: DBOptions = new DBOptions() // options to open the RocksDB
-  dbOptions.setCreateIfMissing(true)
-  dbOptions.setCreateMissingColumnFamilies(true)
+  private val bloomFilter = new BloomFilter()
+  private val tableFormatConfig = new BlockBasedTableConfig()
+  tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024)
+  tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024))
+  tableFormatConfig.setFilterPolicy(bloomFilter)
+  tableFormatConfig.setFormatVersion(conf.formatVersion)
 
+  private val dbOptions = new Options() // options to open the RocksDB
+  dbOptions.setCreateIfMissing(true)
+  dbOptions.setTableFormatConfig(tableFormatConfig)
   private val dbLogger = createLogger() // for forwarding RocksDB native logs to log4j
   dbOptions.setStatistics(new Statistics())
   private val nativeStats = dbOptions.statistics()
@@ -85,18 +87,18 @@ class RocksDB(
   private val acquireLock = new Object
 
   @volatile private var db: NativeRocksDB = _
-  @volatile private var columnFamilyHandles: util.Map[String, ColumnFamilyHandle] = _
-  @volatile private var defaultColumnFamilyHandle: ColumnFamilyHandle = _
   @volatile private var loadedVersion = -1L   // -1 = nothing valid is loaded
   @volatile private var numKeysOnLoadedVersion = 0L
   @volatile private var numKeysOnWritingVersion = 0L
   @volatile private var fileManagerMetrics = RocksDBFileManagerMetrics.EMPTY_METRICS
+  @volatile private var customMetadataOnLoadedVersion: Map[String, String] = Map.empty
+  @volatile private var customMetadataOnWritingVersion: Map[String, String] = Map.empty
 
   @GuardedBy("acquireLock")
   @volatile private var acquiredThreadInfo: AcquiredThreadInfo = _
 
   private val prefixScanReuseIter =
-    new java.util.concurrent.ConcurrentHashMap[(Long, Int), RocksIterator]()
+    new java.util.concurrent.ConcurrentHashMap[Long, RocksIterator]()
 
   /**
    * Load the given version of data in a native RocksDB instance.
@@ -114,6 +116,8 @@ class RocksDB(
         openDB()
         numKeysOnWritingVersion = metadata.numKeys
         numKeysOnLoadedVersion = metadata.numKeys
+        customMetadataOnLoadedVersion = metadata.customMetadata
+        customMetadataOnWritingVersion = metadata.customMetadata
         loadedVersion = version
         fileManagerMetrics = fileManager.latestLoadCheckpointMetrics
       }
@@ -137,28 +141,7 @@ class RocksDB(
    * @note This will return the last written value even if it was uncommitted.
    */
   def get(key: Array[Byte]): Array[Byte] = {
-    get(defaultColumnFamilyHandle, key)
-  }
-
-  // FIXME: method doc
-  def get(cf: String, key: Array[Byte]): Array[Byte] = {
-    get(findColumnFamilyHandle(cf), key)
-  }
-
-  private def get(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = {
-    writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
-  }
-
-  def merge(key: Array[Byte], value: Array[Byte]): Unit = {
-    merge(defaultColumnFamilyHandle, key, value)
-  }
-
-  def merge(cf: String, key: Array[Byte], value: Array[Byte]): Unit = {
-    merge(findColumnFamilyHandle(cf), key, value)
-  }
-
-  private def merge(cfHandle: ColumnFamilyHandle, key: Array[Byte], value: Array[Byte]): Unit = {
-    writeBatch.merge(cfHandle, key, value)
+    writeBatch.getFromBatchAndDB(db, readOptions, key)
   }
 
   /**
@@ -166,20 +149,8 @@ class RocksDB(
    * @note This update is not committed to disk until commit() is called.
    */
   def put(key: Array[Byte], value: Array[Byte]): Array[Byte] = {
-    put(defaultColumnFamilyHandle, key, value)
-  }
-
-  // FIXME: method doc
-  def put(cf: String, key: Array[Byte], value: Array[Byte]): Array[Byte] = {
-    put(findColumnFamilyHandle(cf), key, value)
-  }
-
-  private def put(
-      cfHandle: ColumnFamilyHandle,
-      key: Array[Byte],
-      value: Array[Byte]): Array[Byte] = {
-    val oldValue = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
-    writeBatch.put(cfHandle, key, value)
+    val oldValue = writeBatch.getFromBatchAndDB(db, readOptions, key)
+    writeBatch.put(key, value)
     if (oldValue == null) {
       numKeysOnWritingVersion += 1
     }
@@ -191,18 +162,9 @@ class RocksDB(
    * @note This update is not committed to disk until commit() is called.
    */
   def remove(key: Array[Byte]): Array[Byte] = {
-    remove(defaultColumnFamilyHandle, key)
-  }
-
-  // FIXME: method doc
-  def remove(cf: String, key: Array[Byte]): Array[Byte] = {
-    remove(findColumnFamilyHandle(cf), key)
-  }
-
-  private def remove(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = {
-    val value = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
+    val value = writeBatch.getFromBatchAndDB(db, readOptions, key)
     if (value != null) {
-      writeBatch.delete(cfHandle, key)
+      writeBatch.remove(key)
       numKeysOnWritingVersion -= 1
     }
     value
@@ -211,17 +173,8 @@ class RocksDB(
   /**
    * Get an iterator of all committed and uncommitted key-value pairs.
    */
-  def iterator(): NextIterator[ByteArrayPair] = {
-    iterator(defaultColumnFamilyHandle)
-  }
-
-  // FIXME: doc
-  def iterator(cf: String): NextIterator[ByteArrayPair] = {
-    iterator(findColumnFamilyHandle(cf))
-  }
-
-  private def iterator(cfHandle: ColumnFamilyHandle): NextIterator[ByteArrayPair] = {
-    val iter = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle))
+  def iterator(): Iterator[ByteArrayPair] = {
+    val iter = writeBatch.newIteratorWithBase(db.newIterator())
     logInfo(s"Getting iterator from version $loadedVersion")
     iter.seekToFirst()
 
@@ -248,20 +201,11 @@ class RocksDB(
   }
 
   def prefixScan(prefix: Array[Byte]): Iterator[ByteArrayPair] = {
-    prefixScan(defaultColumnFamilyHandle, prefix)
-  }
-
-  def prefixScan(cf: String, prefix: Array[Byte]): Iterator[ByteArrayPair] = {
-    prefixScan(findColumnFamilyHandle(cf), prefix)
-  }
-
-  private def prefixScan(
-      cfHandle: ColumnFamilyHandle, prefix: Array[Byte]): Iterator[ByteArrayPair] = {
     val threadId = Thread.currentThread().getId
-    val iter = prefixScanReuseIter.computeIfAbsent((threadId, cfHandle.getID), key => {
-      val it = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle))
+    val iter = prefixScanReuseIter.computeIfAbsent(threadId, tid => {
+      val it = writeBatch.newIteratorWithBase(db.newIterator())
       logInfo(s"Getting iterator from version $loadedVersion for prefix scan on " +
-        s"thread ID ${key._1} and column family ID ${key._2}")
+        s"thread ID $tid")
       it
     })
 
@@ -283,12 +227,8 @@ class RocksDB(
     }
   }
 
-  private def findColumnFamilyHandle(cf: String): ColumnFamilyHandle = {
-    val cfHandle = columnFamilyHandles.get(cf)
-    if (cfHandle == null) {
-      throw new IllegalArgumentException(s"Handle for column family $cf is not found")
-    }
-    cfHandle
+  def setCustomMetadata(metadata: Map[String, String]): Unit = {
+    customMetadataOnWritingVersion = metadata
   }
 
   /**
@@ -310,16 +250,11 @@ class RocksDB(
       val writeTimeMs = timeTakenMs { db.write(writeOptions, writeBatch) }
 
       logInfo(s"Flushing updates for $newVersion")
-      val flushTimeMs = timeTakenMs {
-        db.flush(flushOptions,
-          new util.ArrayList[ColumnFamilyHandle](columnFamilyHandles.values()))
-      }
+      val flushTimeMs = timeTakenMs { db.flush(flushOptions) }
 
       val compactTimeMs = if (conf.compactOnCommit) {
         logInfo("Compacting")
-        timeTakenMs {
-          columnFamilyHandles.values().forEach(cfHandle => db.compactRange(cfHandle))
-        }
+        timeTakenMs { db.compactRange() }
       } else 0
 
       logInfo("Pausing background work")
@@ -335,9 +270,11 @@ class RocksDB(
 
       logInfo(s"Syncing checkpoint for $newVersion to DFS")
       val fileSyncTimeMs = timeTakenMs {
-        fileManager.saveCheckpointToDfs(checkpointDir, newVersion, numKeysOnWritingVersion)
+        fileManager.saveCheckpointToDfs(checkpointDir, newVersion, numKeysOnWritingVersion,
+          customMetadataOnWritingVersion.toMap)
       }
       numKeysOnLoadedVersion = numKeysOnWritingVersion
+      customMetadataOnLoadedVersion = customMetadataOnWritingVersion
       loadedVersion = newVersion
       fileManagerMetrics = fileManager.latestSaveCheckpointMetrics
       commitLatencyMs ++= Map(
@@ -352,7 +289,6 @@ class RocksDB(
       loadedVersion
     } catch {
       case t: Throwable =>
-        logWarning(s"ERROR! exc: $t", t)
         loadedVersion = -1  // invalidate loaded version
         throw t
     } finally {
@@ -369,6 +305,7 @@ class RocksDB(
     closePrefixScanIterators()
     writeBatch.clear()
     numKeysOnWritingVersion = numKeysOnLoadedVersion
+    customMetadataOnWritingVersion = customMetadataOnLoadedVersion
     release()
     logInfo(s"Rolled back to $loadedVersion")
   }
@@ -404,6 +341,8 @@ class RocksDB(
   /** Get the latest version available in the DFS */
   def getLatestVersion(): Long = fileManager.getLatestVersion()
 
+  def getCustomMetadata(): Map[String, String] = customMetadataOnWritingVersion
+
   /** Get current instantaneous statistics */
   def metrics: RocksDBMetrics = {
     import HistogramType._
@@ -496,43 +435,12 @@ class RocksDB(
 
   private def openDB(): Unit = {
     assert(db == null)
-
-    val columnFamilyDescriptors = new util.ArrayList[ColumnFamilyDescriptor]()
-    columnFamilies.foreach { cf =>
-      val bloomFilter = new BloomFilter()
-      val tableFormatConfig = new BlockBasedTableConfig()
-      tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024)
-      tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024))
-      tableFormatConfig.setFilterPolicy(bloomFilter)
-      tableFormatConfig.setFormatVersion(conf.formatVersion)
-
-      val columnFamilyOptions = new ColumnFamilyOptions()
-      columnFamilyOptions.setTableFormatConfig(tableFormatConfig)
-      columnFamilyDescriptors.add(new ColumnFamilyDescriptor(cf.getBytes(), columnFamilyOptions))
-    }
-
-    val cfHandles = new util.ArrayList[ColumnFamilyHandle](columnFamilyDescriptors.size())
-    db = NativeRocksDB.open(dbOptions, workingDir.toString, columnFamilyDescriptors,
-      cfHandles)
-
-    columnFamilyHandles = new util.HashMap[String, ColumnFamilyHandle]()
-    columnFamilies.indices.foreach { idx =>
-      columnFamilyHandles.put(columnFamilies(idx), cfHandles.get(idx))
-    }
-
-    // FIXME: constant
-    defaultColumnFamilyHandle = columnFamilyHandles.get("default")
-
+    db = NativeRocksDB.open(dbOptions, workingDir.toString)
     logInfo(s"Opened DB with conf ${conf}")
   }
 
   private def closeDB(): Unit = {
     if (db != null) {
-      columnFamilyHandles.entrySet().forEach(pair => db.destroyColumnFamilyHandle(pair.getValue))
-      columnFamilyHandles.clear()
-      columnFamilyHandles = null
-      defaultColumnFamilyHandle = null
-
       db.close()
       db = null
     }
@@ -546,17 +454,10 @@ class RocksDB(
         // Warn is mapped to info because RocksDB warn is too verbose
         // (e.g. dumps non-warning stuff like stats)
         val loggingFunc: ( => String) => Unit = infoLogLevel match {
-          /*
           case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_)
           case InfoLogLevel.WARN_LEVEL | InfoLogLevel.INFO_LEVEL => logInfo(_)
           case InfoLogLevel.DEBUG_LEVEL => logDebug(_)
           case _ => logTrace(_)
-           */
-          case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_)
-          case InfoLogLevel.WARN_LEVEL => logWarning(_)
-          case InfoLogLevel.INFO_LEVEL => logInfo(_)
-          case InfoLogLevel.DEBUG_LEVEL => logDebug(_)
-          case _ => logTrace(_)
         }
         loggingFunc(s"[NativeRocksDB-${infoLogLevel.getValue}] $logMsg")
       }
@@ -702,7 +603,7 @@ object RocksDBMetrics {
 
 /** Class to wrap RocksDB's native histogram */
 case class RocksDBNativeHistogram(
-    sum: Long, avg: Double, stddev: Double, median: Double, p95: Double, p99: Double, count: Long) {
+  sum: Long, avg: Double, stddev: Double, median: Double, p95: Double, p99: Double, count: Long) {
   def json: String = Serialization.write(this)(RocksDBMetrics.format)
 }
 
@@ -733,4 +634,3 @@ case class AcquiredThreadInfo() {
     s"[ThreadId: ${threadRef.get.map(_.getId)}$taskStr]"
   }
 }
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
index 23cdbd0..567f916 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
@@ -152,11 +152,15 @@ class RocksDBFileManager(
   def latestSaveCheckpointMetrics: RocksDBFileManagerMetrics = saveCheckpointMetrics
 
   /** Save all the files in given local checkpoint directory as a committed version in DFS */
-  def saveCheckpointToDfs(checkpointDir: File, version: Long, numKeys: Long): Unit = {
+  def saveCheckpointToDfs(
+      checkpointDir: File,
+      version: Long,
+      numKeys: Long,
+      customMetadata: Map[String, String] = Map.empty): Unit = {
     logFilesInDir(checkpointDir, s"Saving checkpoint files for version $version")
     val (localImmutableFiles, localOtherFiles) = listRocksDBFiles(checkpointDir)
     val rocksDBFiles = saveImmutableFilesToDfs(version, localImmutableFiles)
-    val metadata = RocksDBCheckpointMetadata(rocksDBFiles, numKeys)
+    val metadata = RocksDBCheckpointMetadata(rocksDBFiles, numKeys, customMetadata)
     val metadataFile = localMetadataFile(checkpointDir)
     metadata.writeToFile(metadataFile)
     logInfo(s"Written metadata for version $version:\n${metadata.prettyJson}")
@@ -184,7 +188,7 @@ class RocksDBFileManager(
     val metadata = if (version == 0) {
       if (localDir.exists) Utils.deleteRecursively(localDir)
       localDir.mkdirs()
-      RocksDBCheckpointMetadata(Seq.empty, 0)
+      RocksDBCheckpointMetadata(Seq.empty, 0, Map.empty)
     } else {
       // Delete all non-immutable files in local dir, and unzip new ones from DFS commit file
       listRocksDBFiles(localDir)._2.foreach(_.delete())
@@ -540,12 +544,20 @@ object RocksDBFileManagerMetrics {
 case class RocksDBCheckpointMetadata(
     sstFiles: Seq[RocksDBSstFile],
     logFiles: Seq[RocksDBLogFile],
-    numKeys: Long) {
+    numKeys: Long,
+    customMetadata: Map[String, String]) {
   import RocksDBCheckpointMetadata._
 
   def json: String = {
-    // We turn this field into a null to avoid write a empty logFiles field in the json.
-    val nullified = if (logFiles.isEmpty) this.copy(logFiles = null) else this
+    // We turn the fields below into null when they are empty, to avoid writing them in the json:
+    // - logFiles
+    // - customMetadata
+    val nullified = {
+      var cur = this
+      cur = if (logFiles.isEmpty) cur.copy(logFiles = null) else cur
+      cur = if (customMetadata.isEmpty) cur.copy(customMetadata = null) else cur
+      cur
+    }
     mapper.writeValueAsString(nullified)
   }
 
@@ -593,11 +605,18 @@ object RocksDBCheckpointMetadata {
     }
   }
 
-  def apply(rocksDBFiles: Seq[RocksDBImmutableFile], numKeys: Long): RocksDBCheckpointMetadata = {
+  def apply(
+      rocksDBFiles: Seq[RocksDBImmutableFile],
+      numKeys: Long): RocksDBCheckpointMetadata = apply(rocksDBFiles, numKeys, Map.empty)
+
+  def apply(
+      rocksDBFiles: Seq[RocksDBImmutableFile],
+      numKeys: Long,
+      customMetadata: Map[String, String]): RocksDBCheckpointMetadata = {
     val sstFiles = rocksDBFiles.collect { case file: RocksDBSstFile => file }
     val logFiles = rocksDBFiles.collect { case file: RocksDBLogFile => file }
 
-    RocksDBCheckpointMetadata(sstFiles, logFiles, numKeys)
+    RocksDBCheckpointMetadata(sstFiles, logFiles, numKeys, customMetadata)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index 323826d..84e9a8d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -22,7 +22,7 @@ import java.nio.ByteOrder
 
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION}
-import org.apache.spark.sql.types.{StructField, StructType, TimestampType}
+import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.unsafe.Platform
 
 sealed trait RocksDBStateEncoder {
@@ -30,11 +30,6 @@ sealed trait RocksDBStateEncoder {
   def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
   def extractPrefixKey(key: UnsafeRow): UnsafeRow
 
-  def supportEventTimeIndex: Boolean
-  def extractEventTime(key: UnsafeRow): Long
-  def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte]
-  def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte])
-
   def encodeKey(row: UnsafeRow): Array[Byte]
   def encodeValue(row: UnsafeRow): Array[Byte]
 
@@ -47,13 +42,11 @@ object RocksDBStateEncoder {
   def getEncoder(
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
-      eventTimeColIdx: Array[Int]): RocksDBStateEncoder = {
+      numColsPrefixKey: Int): RocksDBStateEncoder = {
     if (numColsPrefixKey > 0) {
-      // FIXME: need to deal with prefix case as well
       new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
     } else {
-      new NoPrefixKeyStateEncoder(keySchema, valueSchema, eventTimeColIdx)
+      new NoPrefixKeyStateEncoder(keySchema, valueSchema)
     }
   }
 
@@ -228,23 +221,6 @@ class PrefixKeyScanStateEncoder(
   }
 
   override def supportPrefixKeyScan: Boolean = true
-
-  override def supportEventTimeIndex: Boolean = false
-
-  // FIXME: fix me!
-  def extractEventTime(key: UnsafeRow): Long = {
-    throw new IllegalStateException("This encoder doesn't support event time index!")
-  }
-
-  // FIXME: fix me!
-  def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = {
-    throw new IllegalStateException("This encoder doesn't support event time index!")
-  }
-
-  // FIXME: fix me!
-  def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = {
-    throw new IllegalStateException("This encoder doesn't support event time index!")
-  }
 }
 
 /**
@@ -259,8 +235,7 @@ class PrefixKeyScanStateEncoder(
  */
 class NoPrefixKeyStateEncoder(
     keySchema: StructType,
-    valueSchema: StructType,
-    eventTimeColIdx: Array[Int]) extends RocksDBStateEncoder {
+    valueSchema: StructType) extends RocksDBStateEncoder {
 
   import RocksDBStateEncoder._
 
@@ -269,32 +244,6 @@ class NoPrefixKeyStateEncoder(
   private val valueRow = new UnsafeRow(valueSchema.size)
   private val rowTuple = new UnsafeRowPair()
 
-  validateColumnTypeOnEventTimeColumn()
-
-  private def validateColumnTypeOnEventTimeColumn(): Unit = {
-    if (eventTimeColIdx.nonEmpty) {
-      var curSchema: StructType = keySchema
-      eventTimeColIdx.dropRight(1).foreach { idx =>
-        curSchema(idx).dataType match {
-          case stType: StructType =>
-            curSchema = stType
-          case _ =>
-            // FIXME: better error message
-            throw new IllegalStateException("event time column is not properly specified! " +
-              s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
-        }
-      }
-
-      curSchema(eventTimeColIdx.last).dataType match {
-        case _: TimestampType =>
-        case _ =>
-          // FIXME: better error message
-          throw new IllegalStateException("event time column is not properly specified! " +
-            s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
-      }
-    }
-  }
-
   override def encodeKey(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
 
   override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
@@ -337,41 +286,4 @@ class NoPrefixKeyStateEncoder(
   override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
     throw new IllegalStateException("This encoder doesn't support prefix key!")
   }
-
-  override def supportEventTimeIndex: Boolean = eventTimeColIdx.nonEmpty
-
-  override def extractEventTime(key: UnsafeRow): Long = {
-    var curRow: UnsafeRow = key
-    var curSchema: StructType = keySchema
-
-    eventTimeColIdx.dropRight(1).foreach { idx =>
-      // validation is done in initialization phase
-      curSchema = curSchema(idx).dataType.asInstanceOf[StructType]
-      curRow = curRow.getStruct(idx, curSchema.length)
-    }
-
-    curRow.getLong(eventTimeColIdx.last) / 1000
-  }
-
-  override def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = {
-    val newKey = new Array[Byte](8 + encodedKey.length)
-
-    Platform.putLong(newKey, Platform.BYTE_ARRAY_OFFSET,
-      BinarySortable.encodeToBinarySortableLong(timestamp))
-    Platform.copyMemory(encodedKey, Platform.BYTE_ARRAY_OFFSET, newKey,
-      Platform.BYTE_ARRAY_OFFSET + 8, encodedKey.length)
-
-    newKey
-  }
-
-  override def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = {
-    val encoded = Platform.getLong(eventTimeBytes, Platform.BYTE_ARRAY_OFFSET)
-    val timestamp = BinarySortable.decodeBinarySortableLong(encoded)
-
-    val key = new Array[Byte](eventTimeBytes.length - 8)
-    Platform.copyMemory(eventTimeBytes, Platform.BYTE_ARRAY_OFFSET + 8,
-      key, Platform.BYTE_ARRAY_OFFSET, eventTimeBytes.length - 8)
-
-    (timestamp, key)
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 1d66220..d7c2e0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -24,15 +24,14 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.unsafe.Platform
+import org.apache.spark.sql.types.{StructType, TimestampType}
 import org.apache.spark.util.{NextIterator, Utils}
 
 private[sql] class RocksDBStateStoreProvider
   extends StateStoreProvider with Logging with Closeable {
   import RocksDBStateStoreProvider._
 
-  class RocksDBStateStore(lastVersion: Long) extends StateStore {
+  class RocksDBStateStore(lastVersion: Long, eventTimeColIdx: Array[Int]) extends StateStore {
     /** Trait and classes representing the internal state of the store */
     trait STATE
     case object UPDATING extends STATE
@@ -42,6 +41,40 @@ private[sql] class RocksDBStateStoreProvider
     @volatile private var state: STATE = UPDATING
     @volatile private var isValidated = false
 
+    private val supportEventTimeIndex: Boolean = eventTimeColIdx.nonEmpty
+
+    if (supportEventTimeIndex) {
+      validateColumnTypeOnEventTimeColumn()
+    }
+
+    private def validateColumnTypeOnEventTimeColumn(): Unit = {
+      require(eventTimeColIdx.nonEmpty)
+
+      var curSchema: StructType = keySchema
+      eventTimeColIdx.dropRight(1).foreach { idx =>
+        curSchema(idx).dataType match {
+          case stType: StructType =>
+            curSchema = stType
+          case _ =>
+            // FIXME: better error message
+            throw new IllegalStateException("event time column is not properly specified! " +
+              s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
+        }
+      }
+
+      curSchema(eventTimeColIdx.last).dataType match {
+        case _: TimestampType =>
+        case _ =>
+          // FIXME: better error message
+          throw new IllegalStateException("event time column is not properly specified! " +
+            s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
+      }
+    }
+
+    private var lowestEventTime: Long = Option(rocksDB.getCustomMetadata())
+      .flatMap(_.get(METADATA_KEY_LOWEST_EVENT_TIME).map(_.toLong))
+      .getOrElse(INVALID_LOWEST_EVENT_TIME_VALUE)
+
     override def id: StateStoreId = RocksDBStateStoreProvider.this.stateStoreId
 
     override def version: Long = lastVersion
@@ -66,19 +99,29 @@ private[sql] class RocksDBStateStoreProvider
       val encodedValue = encoder.encodeValue(value)
       rocksDB.put(encodedKey, encodedValue)
 
-      if (encoder.supportEventTimeIndex) {
-        val timestamp = encoder.extractEventTime(key)
-        val tsKey = encoder.encodeEventTimeIndexKey(timestamp, encodedKey)
-
-        rocksDB.put(CF_EVENT_TIME_INDEX, tsKey, Array.empty)
+      if (supportEventTimeIndex) {
+        val eventTimeValue = extractEventTime(key)
+        if (lowestEventTime != INVALID_LOWEST_EVENT_TIME_VALUE
+          && lowestEventTime > eventTimeValue) {
+          lowestEventTime = eventTimeValue
+        }
       }
     }
 
     override def remove(key: UnsafeRow): Unit = {
       verify(state == UPDATING, "Cannot remove after already committed or aborted")
       verify(key != null, "Key cannot be null")
-      // FIXME: this should reflect the index
+
       rocksDB.remove(encoder.encodeKey(key))
+      if (supportEventTimeIndex) {
+        val eventTimeValue = extractEventTime(key)
+        if (lowestEventTime == eventTimeValue) {
+          // We can't determine the next lowest event time without scanning all keys.
+          // Mark the lowest event time value as invalid, so that a scan happens in the evict
+          // phase and the value is correctly updated later.
+          lowestEventTime = INVALID_LOWEST_EVENT_TIME_VALUE
+        }
+      }
     }
 
     override def iterator(): Iterator[UnsafeRowPair] = {
@@ -100,8 +143,89 @@ private[sql] class RocksDBStateStoreProvider
       rocksDB.prefixScan(prefix).map(kv => encoder.decode(kv))
     }
 
+    /** FIXME: method doc */
+    override def evictOnWatermark(
+      watermarkMs: Long,
+      altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+      if (supportEventTimeIndex) {
+        // Convert lowestEventTime (microseconds) to milliseconds and compare it to watermarkMs.
+        // Subtract 1 ms to avoid an edge case on conversion from microseconds to milliseconds.
+        if (lowestEventTime != INVALID_LOWEST_EVENT_TIME_VALUE
+          && ((lowestEventTime / 1000) - 1 > watermarkMs)) {
+          Iterator.empty
+        } else {
+          // start with invalidating the lowest event time
+          lowestEventTime = INVALID_LOWEST_EVENT_TIME_VALUE
+
+          new NextIterator[UnsafeRowPair] {
+            private val iter = rocksDB.iterator()
+
+            // here we use Long.MaxValue as invalid value
+            private var lowestEventTimeInIter = Long.MaxValue
+
+            override protected def getNext(): UnsafeRowPair = {
+              var result: UnsafeRowPair = null
+              while (result == null && iter.hasNext) {
+                val kv = iter.next()
+                val rowPair = encoder.decode(kv)
+                if (altPred(rowPair)) {
+                  rocksDB.remove(kv.key)
+                  result = rowPair
+                } else {
+                  val eventTime = extractEventTime(rowPair.key)
+                  if (lowestEventTimeInIter > eventTime) {
+                    lowestEventTimeInIter = eventTime
+                  }
+                }
+              }
+
+              if (result == null) {
+                finished = true
+                null
+              } else {
+                result
+              }
+            }
+
+            override protected def close(): Unit = {
+              if (lowestEventTimeInIter != Long.MaxValue) {
+                lowestEventTime = lowestEventTimeInIter
+              }
+            }
+          }
+        }
+      } else {
+        rocksDB.iterator().flatMap { kv =>
+          val rowPair = encoder.decode(kv)
+          if (altPred(rowPair)) {
+            rocksDB.remove(kv.key)
+            Some(rowPair)
+          } else {
+            None
+          }
+        }
+      }
+    }
+
+    private def extractEventTime(key: UnsafeRow): Long = {
+      var curRow: UnsafeRow = key
+      var curSchema: StructType = keySchema
+
+      eventTimeColIdx.dropRight(1).foreach { idx =>
+        // validation is done in initialization phase
+        curSchema = curSchema(idx).dataType.asInstanceOf[StructType]
+        curRow = curRow.getStruct(idx, curSchema.length)
+      }
+
+      curRow.getLong(eventTimeColIdx.last)
+    }
+
     override def commit(): Long = synchronized {
       verify(state == UPDATING, "Cannot commit after already committed or aborted")
+
+      // set the metadata to RocksDB instance so that it can be committed as well
+      rocksDB.setCustomMetadata(Map(METADATA_KEY_LOWEST_EVENT_TIME -> lowestEventTime.toString))
+
       val newVersion = rocksDB.commit()
       state = COMMITTED
       logInfo(s"Committed $newVersion for $id")
@@ -173,68 +297,6 @@ private[sql] class RocksDBStateStoreProvider
 
     /** Return the [[RocksDB]] instance in this store. This is exposed mainly for testing. */
     def dbInstance(): RocksDB = rocksDB
-
-    /** FIXME: method doc */
-    override def evictOnWatermark(
-        watermarkMs: Long,
-        altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
-
-      if (encoder.supportEventTimeIndex) {
-        val kv = new ByteArrayPair()
-
-        // FIXME: DEBUG
-        // logWarning(s"DEBUG: start iterating event time index, watermark $watermarkMs")
-
-        new NextIterator[UnsafeRowPair] {
-          private val iter = rocksDB.iterator(CF_EVENT_TIME_INDEX)
-          override protected def getNext(): UnsafeRowPair = {
-            if (iter.hasNext) {
-              val pair = iter.next()
-
-              val encodedTs = Platform.getLong(pair.key, Platform.BYTE_ARRAY_OFFSET)
-              val decodedTs = BinarySortable.decodeBinarySortableLong(encodedTs)
-
-              // FIXME: DEBUG
-              // logWarning(s"DEBUG: decoded TS: $decodedTs")
-
-              if (decodedTs > watermarkMs) {
-                finished = true
-                null
-              } else {
-                // FIXME: can we leverage deleteRange to bulk delete on index?
-                rocksDB.remove(CF_EVENT_TIME_INDEX, pair.key)
-                val (_, encodedKey) = encoder.decodeEventTimeIndexKey(pair.key)
-                val value = rocksDB.get(encodedKey)
-                if (value == null) {
-                  throw new IllegalStateException("Event time index has been broken!")
-                }
-                kv.set(encodedKey, value)
-                val rowPair = encoder.decode(kv)
-                rocksDB.remove(encodedKey)
-                rowPair
-              }
-            } else {
-              finished = true
-              null
-            }
-          }
-
-          override protected def close(): Unit = {
-            iter.closeIfNeeded()
-          }
-        }
-      } else {
-        rocksDB.iterator().flatMap { kv =>
-          val rowPair = encoder.decode(kv)
-          if (altPred(rowPair)) {
-            rocksDB.remove(kv.key)
-            Some(rowPair)
-          } else {
-            None
-          }
-        }
-      }
-    }
   }
 
   override def init(
@@ -257,7 +319,7 @@ private[sql] class RocksDBStateStoreProvider
     this.operatorContext = operatorContext
 
     this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema,
-      operatorContext.numColsPrefixKey, operatorContext.eventTimeColIdx)
+      operatorContext.numColsPrefixKey)
 
     rocksDB // lazy initialization
   }
@@ -267,7 +329,7 @@ private[sql] class RocksDBStateStoreProvider
   override def getStore(version: Long): StateStore = {
     require(version >= 0, "Version cannot be less than 0")
     rocksDB.load(version)
-    new RocksDBStateStore(version)
+    new RocksDBStateStore(version, operatorContext.eventTimeColIdx)
   }
 
   override def doMaintenance(): Unit = {
@@ -298,7 +360,6 @@ private[sql] class RocksDBStateStoreProvider
     val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
     val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr)
     new RocksDB(dfsRootDir, RocksDBConf(storeConf),
-      columnFamilies = Seq("default", RocksDBStateStoreProvider.CF_EVENT_TIME_INDEX),
       localRootDir = localRootDir,
       hadoopConf = hadoopConf, loggingId = storeIdStr)
   }
@@ -315,8 +376,8 @@ object RocksDBStateStoreProvider {
   val STATE_ENCODING_NUM_VERSION_BYTES = 1
   val STATE_ENCODING_VERSION: Byte = 0
 
-  // reserved column families
-  val CF_EVENT_TIME_INDEX: String = "__event_time_idx"
+  val INVALID_LOWEST_EVENT_TIME_VALUE: Long = Long.MinValue
+  val METADATA_KEY_LOWEST_EVENT_TIME: String = "lowestEventTimeInState"
 
   // Native operation latencies report as latency in microseconds
   // as SQLMetrics support millis. Convert the value to millis
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
index 9a83f5c..12ac517 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
@@ -66,7 +66,7 @@ object StateStoreBenchmark extends SqlBasedBenchmark {
   private def runEvictBenchmark(): Unit = {
     runBenchmark("evict rows") {
       val numOfRows = Seq(10000) // Seq(1000, 10000, 100000)
-      val numOfTimestamps = Seq(10, 100, 1000)
+      val numOfTimestamps = Seq(100, 1000)
       val numOfEvictionRates = Seq(50, 25, 10, 5, 1, 0) // Seq(100, 75, 50, 25, 1, 0)
 
       numOfRows.foreach { numOfRow =>
@@ -116,7 +116,7 @@ object StateStoreBenchmark extends SqlBasedBenchmark {
               val rocksDBStore = rocksDBProvider.getStore(committedVersion)
 
               timer.startTiming()
-              evictAsFullScanAndRemove(rocksDBStore, maxTimestampToEvictInMillis)
+              evictAsFullScanAndRemove(rocksDBStore, maxTimestampToEvictInMillis, numOfRowsToEvict)
               timer.stopTiming()
 
               rocksDBStore.abort()
@@ -126,7 +126,7 @@ object StateStoreBenchmark extends SqlBasedBenchmark {
               val rocksDBStore = rocksDBProvider.getStore(committedVersion)
 
               timer.startTiming()
-              evictAsNewEvictApi(rocksDBStore, maxTimestampToEvictInMillis)
+              evictAsNewEvictApi(rocksDBStore, maxTimestampToEvictInMillis, numOfRowsToEvict)
               timer.stopTiming()
 
               rocksDBStore.abort()
@@ -487,20 +487,29 @@ object StateStoreBenchmark extends SqlBasedBenchmark {
 
   private def evictAsFullScanAndRemove(
       store: StateStore,
-      maxTimestampToEvict: Long): Unit = {
+      maxTimestampToEvict: Long,
+      expectedNumOfRows: Long): Unit = {
+    var removedRows: Long = 0
     store.iterator().foreach { r =>
-      if (r.key.getLong(1) < maxTimestampToEvict) {
+      if (r.key.getLong(1) / 1000 <= maxTimestampToEvict) {
         store.remove(r.key)
+        removedRows += 1
       }
     }
+    assert(removedRows == expectedNumOfRows,
+      s"expected: $expectedNumOfRows actual: $removedRows")
   }
 
   private def evictAsNewEvictApi(
       store: StateStore,
-      maxTimestampToEvict: Long): Unit = {
+      maxTimestampToEvict: Long,
+      expectedNumOfRows: Long): Unit = {
+    var removedRows: Long = 0
     store.evictOnWatermark(maxTimestampToEvict, pair => {
-      pair.key.getLong(1) < maxTimestampToEvict
-    }).foreach { _ => }
+      pair.key.getLong(1) / 1000 <= maxTimestampToEvict
+    }).foreach { _ => removedRows += 1 }
+    assert(removedRows == expectedNumOfRows,
+      s"expected: $expectedNumOfRows actual: $removedRows")
   }
 
   private def fullScanAndCompareTimestamp(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 1ee2748..31e49ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -355,14 +355,18 @@ class RocksDBSuite extends SparkFunSuite {
       RocksDBCheckpointMetadata(Seq.empty, 0L),
       """{"sstFiles":[],"numKeys":0}"""
     )
-    // shouldn't include the "logFiles" field in json when it's empty
+    // shouldn't include the "logFiles" and "customMetadata" fields in json when they're empty
     checkJsonRoundtrip(
       RocksDBCheckpointMetadata(sstFiles, 12345678901234L),
       """{"sstFiles":[{"localFileName":"00001.sst","dfsSstFileName":"00001-uuid.sst","sizeBytes":12345678901234}],"numKeys":12345678901234}"""
     )
+    // shouldn't include the "customMetadata" field in json when it's empty
     checkJsonRoundtrip(
-      RocksDBCheckpointMetadata(sstFiles, logFiles, 12345678901234L),
+      RocksDBCheckpointMetadata(sstFiles, logFiles, 12345678901234L, Map.empty),
       """{"sstFiles":[{"localFileName":"00001.sst","dfsSstFileName":"00001-uuid.sst","sizeBytes":12345678901234}],"logFiles":[{"localFileName":"00001.log","dfsLogFileName":"00001-uuid.log","sizeBytes":12345678901234}],"numKeys":12345678901234}""")
+
+    // FIXME: test customMetadata here
+
     // scalastyle:on line.size.limit
   }
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[spark] 01/02: WIP: benchmark test code done

Posted by ka...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch WIP-optimize-eviction-in-rocksdb-state-store
in repository https://gitbox.apache.org/repos/asf/spark.git

commit 21d4f96c6c7a56acce30d485a0747e1d62d7ad27
Author: Jungtaek Lim <ka...@gmail.com>
AuthorDate: Thu Sep 30 11:53:39 2021 +0900

    WIP: benchmark test code done
---
 .../streaming/FlatMapGroupsWithStateExec.scala     |   7 +-
 .../state/HDFSBackedStateStoreProvider.scala       |  42 +-
 .../sql/execution/streaming/state/RocksDB.scala    | 158 ++++-
 .../streaming/state/RocksDBStateEncoder.scala      | 135 ++++-
 .../state/RocksDBStateStoreProvider.scala          | 100 +++-
 .../sql/execution/streaming/state/StateStore.scala |  40 +-
 .../execution/streaming/state/StateStoreRDD.scala  |   8 +-
 .../state/StreamingAggregationStateManager.scala   |  23 +
 .../state/SymmetricHashJoinStateManager.scala      |   3 +-
 .../sql/execution/streaming/state/package.scala    |  12 +-
 .../execution/streaming/statefulOperators.scala    |  62 +-
 .../sql/execution/streaming/streamingLimits.scala  |   5 +-
 .../execution/benchmark/StateStoreBenchmark.scala  | 633 +++++++++++++++++++++
 ...ngSortWithSessionWindowStateIteratorSuite.scala |   7 +-
 .../streaming/state/MemoryStateStore.scala         |  14 +
 .../state/RocksDBStateStoreIntegrationSuite.scala  |  60 +-
 .../streaming/state/RocksDBStateStoreSuite.scala   |   6 +-
 .../streaming/state/StateStoreRDDSuite.scala       |  18 +-
 .../streaming/state/StateStoreSuite.scala          |  31 +-
 .../StreamingSessionWindowStateManagerSuite.scala  |   4 +-
 .../apache/spark/sql/streaming/StreamSuite.scala   |   4 +-
 .../sql/streaming/StreamingAggregationSuite.scala  |  34 +-
 22 files changed, 1270 insertions(+), 136 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index a00a622..381aeb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -224,21 +224,23 @@ case class FlatMapGroupsWithStateExec(
           val stateStoreId = StateStoreId(
             stateInfo.get.checkpointLocation, stateInfo.get.operatorId, partitionId)
           val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId)
+          // FIXME: would setting prefixScan / evict help?
           val store = StateStore.get(
             storeProviderId,
             groupingAttributes.toStructType,
             stateManager.stateSchema,
-            numColsPrefixKey = 0,
+            StatefulOperatorContext(),
             stateInfo.get.storeVersion, storeConf, hadoopConfBroadcast.value.value)
           val processor = new InputProcessor(store)
           processDataWithPartition(childDataIterator, store, processor, Some(initStateIterator))
       }
     } else {
+      // FIXME: would setting prefixScan / evict help?
       child.execute().mapPartitionsWithStateStore[InternalRow](
         getStateInfo,
         groupingAttributes.toStructType,
         stateManager.stateSchema,
-        numColsPrefixKey = 0,
+        StatefulOperatorContext(),
         session.sqlContext.sessionState,
         Some(session.sqlContext.streams.stateStoreCoordinator)
       ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
@@ -334,6 +336,7 @@ case class FlatMapGroupsWithStateExec(
             throw new IllegalStateException(
               s"Cannot filter timed out keys for $timeoutConf")
         }
+        // FIXME: would setting prefixScan / evict help?
         val timingOutPairs = stateManager.getAllState(store).filter { state =>
           state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 75b7dae..96ba2a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -100,8 +100,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
 
     /** Trait and classes representing the internal state of the store */
     trait STATE
+
     case object UPDATING extends STATE
+
     case object COMMITTED extends STATE
+
     case object ABORTED extends STATE
 
     private val newVersion = version + 1
@@ -195,6 +198,22 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     override def toString(): String = {
       s"HDFSStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]"
     }
+
+    /** FIXME: method doc */
+    override def evictOnWatermark(
+        watermarkMs: Long,
+        altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+      // HDFSBackedStateStore doesn't index event time column
+      // FIXME: should we do this for in-memory as well?
+      iterator().filter { pair =>
+        if (altPred.apply(pair)) {
+          remove(pair.key)
+          true
+        } else {
+          false
+        }
+      }
+    }
   }
 
   def getMetricsForProvider(): Map[String, Long] = synchronized {
@@ -219,7 +238,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
 
   private def getLoadedMapForStore(version: Long): HDFSBackedStateStoreMap = synchronized {
     require(version >= 0, "Version cannot be less than 0")
-    val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+    val newMap = HDFSBackedStateStoreMap.create(keySchema, operatorContext.numColsPrefixKey)
     if (version > 0) {
       newMap.putAll(loadMap(version))
     }
@@ -230,7 +249,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): Unit = {
     this.stateStoreId_ = stateStoreId
@@ -240,10 +259,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     this.hadoopConf = hadoopConf
     this.numberOfVersionsToRetainInMemory = storeConf.maxVersionsToRetainInMemory
 
-    require((keySchema.length == 0 && numColsPrefixKey == 0) ||
-      (keySchema.length > numColsPrefixKey), "The number of columns in the key must be " +
-      "greater than the number of columns for prefix key!")
-    this.numColsPrefixKey = numColsPrefixKey
+    require((keySchema.length == 0 && operatorContext.numColsPrefixKey == 0) ||
+      (keySchema.length > operatorContext.numColsPrefixKey), "The number of columns in the key " +
+      "must be greater than the number of columns for prefix key!")
+
+    this.operatorContext = operatorContext
 
     fm.mkdirs(baseDir)
   }
@@ -283,7 +303,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
   @volatile private var storeConf: StateStoreConf = _
   @volatile private var hadoopConf: Configuration = _
   @volatile private var numberOfVersionsToRetainInMemory: Int = _
-  @volatile private var numColsPrefixKey: Int = 0
+  @volatile private var operatorContext: StatefulOperatorContext = _
 
   // TODO: The validation should be moved to a higher level so that it works for all state store
   // implementations
@@ -401,7 +421,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
 
         if (lastAvailableVersion <= 0) {
           // Use an empty map for versions 0 or less.
-          lastAvailableMap = Some(HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey))
+          lastAvailableMap = Some(HDFSBackedStateStoreMap.create(keySchema,
+            operatorContext.numColsPrefixKey))
         } else {
           lastAvailableMap =
             synchronized { Option(loadedMaps.get(lastAvailableVersion)) }
@@ -411,7 +432,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
 
       // Load all the deltas from the version after the last available one up to the target version.
       // The last available version is the one with a full snapshot, so it doesn't need deltas.
-      val resultMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+      val resultMap = HDFSBackedStateStoreMap.create(keySchema,
+        operatorContext.numColsPrefixKey)
       resultMap.putAll(lastAvailableMap.get)
       for (deltaVersion <- lastAvailableVersion + 1 to version) {
         updateFromDeltaFile(deltaVersion, resultMap)
@@ -554,7 +576,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
 
   private def readSnapshotFile(version: Long): Option[HDFSBackedStateStoreMap] = {
     val fileToRead = snapshotFile(version)
-    val map = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+    val map = HDFSBackedStateStoreMap.create(keySchema, operatorContext.numColsPrefixKey)
     var input: DataInputStream = null
 
     try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 1ff8b41..eed7827 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.streaming.state
 
 import java.io.File
+import java.util
 import java.util.Locale
 import javax.annotation.concurrent.GuardedBy
 
@@ -50,9 +51,12 @@ import org.apache.spark.util.{NextIterator, Utils}
  * @param hadoopConf   Hadoop configuration for talking to the remote file system
  * @param loggingId    Id that will be prepended in logs for isolating concurrent RocksDBs
  */
+// FIXME: optionally receiving column families
 class RocksDB(
     dfsRootDir: String,
     val conf: RocksDBConf,
+    // TODO: change "default" to constant
+    columnFamilies: Seq[String] = Seq("default"),
     localRootDir: File = Utils.createTempDir(),
     hadoopConf: Configuration = new Configuration,
     loggingId: String = "") extends Logging {
@@ -65,16 +69,10 @@ class RocksDB(
   private val flushOptions = new FlushOptions().setWaitForFlush(true)  // wait for flush to complete
   private val writeBatch = new WriteBatchWithIndex(true)  // overwrite multiple updates to a key
 
-  private val bloomFilter = new BloomFilter()
-  private val tableFormatConfig = new BlockBasedTableConfig()
-  tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024)
-  tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024))
-  tableFormatConfig.setFilterPolicy(bloomFilter)
-  tableFormatConfig.setFormatVersion(conf.formatVersion)
-
-  private val dbOptions = new Options() // options to open the RocksDB
+  private val dbOptions: DBOptions = new DBOptions() // options to open the RocksDB
   dbOptions.setCreateIfMissing(true)
-  dbOptions.setTableFormatConfig(tableFormatConfig)
+  dbOptions.setCreateMissingColumnFamilies(true)
+
   private val dbLogger = createLogger() // for forwarding RocksDB native logs to log4j
   dbOptions.setStatistics(new Statistics())
   private val nativeStats = dbOptions.statistics()
@@ -87,6 +85,8 @@ class RocksDB(
   private val acquireLock = new Object
 
   @volatile private var db: NativeRocksDB = _
+  @volatile private var columnFamilyHandles: util.Map[String, ColumnFamilyHandle] = _
+  @volatile private var defaultColumnFamilyHandle: ColumnFamilyHandle = _
   @volatile private var loadedVersion = -1L   // -1 = nothing valid is loaded
   @volatile private var numKeysOnLoadedVersion = 0L
   @volatile private var numKeysOnWritingVersion = 0L
@@ -96,7 +96,7 @@ class RocksDB(
   @volatile private var acquiredThreadInfo: AcquiredThreadInfo = _
 
   private val prefixScanReuseIter =
-    new java.util.concurrent.ConcurrentHashMap[Long, RocksIterator]()
+    new java.util.concurrent.ConcurrentHashMap[(Long, Int), RocksIterator]()
 
   /**
    * Load the given version of data in a native RocksDB instance.
@@ -137,7 +137,28 @@ class RocksDB(
    * @note This will return the last written value even if it was uncommitted.
    */
   def get(key: Array[Byte]): Array[Byte] = {
-    writeBatch.getFromBatchAndDB(db, readOptions, key)
+    get(defaultColumnFamilyHandle, key)
+  }
+
+  // FIXME: method doc
+  def get(cf: String, key: Array[Byte]): Array[Byte] = {
+    get(findColumnFamilyHandle(cf), key)
+  }
+
+  private def get(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = {
+    writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
+  }
+
+  def merge(key: Array[Byte], value: Array[Byte]): Unit = {
+    merge(defaultColumnFamilyHandle, key, value)
+  }
+
+  def merge(cf: String, key: Array[Byte], value: Array[Byte]): Unit = {
+    merge(findColumnFamilyHandle(cf), key, value)
+  }
+
+  private def merge(cfHandle: ColumnFamilyHandle, key: Array[Byte], value: Array[Byte]): Unit = {
+    writeBatch.merge(cfHandle, key, value)
   }
 
   /**
@@ -145,8 +166,20 @@ class RocksDB(
    * @note This update is not committed to disk until commit() is called.
    */
   def put(key: Array[Byte], value: Array[Byte]): Array[Byte] = {
-    val oldValue = writeBatch.getFromBatchAndDB(db, readOptions, key)
-    writeBatch.put(key, value)
+    put(defaultColumnFamilyHandle, key, value)
+  }
+
+  // FIXME: method doc
+  def put(cf: String, key: Array[Byte], value: Array[Byte]): Array[Byte] = {
+    put(findColumnFamilyHandle(cf), key, value)
+  }
+
+  private def put(
+      cfHandle: ColumnFamilyHandle,
+      key: Array[Byte],
+      value: Array[Byte]): Array[Byte] = {
+    val oldValue = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
+    writeBatch.put(cfHandle, key, value)
     if (oldValue == null) {
       numKeysOnWritingVersion += 1
     }
@@ -158,9 +191,18 @@ class RocksDB(
    * @note This update is not committed to disk until commit() is called.
    */
   def remove(key: Array[Byte]): Array[Byte] = {
-    val value = writeBatch.getFromBatchAndDB(db, readOptions, key)
+    remove(defaultColumnFamilyHandle, key)
+  }
+
+  // FIXME: method doc
+  def remove(cf: String, key: Array[Byte]): Array[Byte] = {
+    remove(findColumnFamilyHandle(cf), key)
+  }
+
+  private def remove(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = {
+    val value = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
     if (value != null) {
-      writeBatch.remove(key)
+      writeBatch.delete(cfHandle, key)
       numKeysOnWritingVersion -= 1
     }
     value
@@ -169,8 +211,17 @@ class RocksDB(
   /**
    * Get an iterator of all committed and uncommitted key-value pairs.
    */
-  def iterator(): Iterator[ByteArrayPair] = {
-    val iter = writeBatch.newIteratorWithBase(db.newIterator())
+  def iterator(): NextIterator[ByteArrayPair] = {
+    iterator(defaultColumnFamilyHandle)
+  }
+
+  // FIXME: doc
+  def iterator(cf: String): NextIterator[ByteArrayPair] = {
+    iterator(findColumnFamilyHandle(cf))
+  }
+
+  private def iterator(cfHandle: ColumnFamilyHandle): NextIterator[ByteArrayPair] = {
+    val iter = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle))
     logInfo(s"Getting iterator from version $loadedVersion")
     iter.seekToFirst()
 
@@ -197,11 +248,20 @@ class RocksDB(
   }
 
   def prefixScan(prefix: Array[Byte]): Iterator[ByteArrayPair] = {
+    prefixScan(defaultColumnFamilyHandle, prefix)
+  }
+
+  def prefixScan(cf: String, prefix: Array[Byte]): Iterator[ByteArrayPair] = {
+    prefixScan(findColumnFamilyHandle(cf), prefix)
+  }
+
+  private def prefixScan(
+      cfHandle: ColumnFamilyHandle, prefix: Array[Byte]): Iterator[ByteArrayPair] = {
     val threadId = Thread.currentThread().getId
-    val iter = prefixScanReuseIter.computeIfAbsent(threadId, tid => {
-      val it = writeBatch.newIteratorWithBase(db.newIterator())
+    val iter = prefixScanReuseIter.computeIfAbsent((threadId, cfHandle.getID), key => {
+      val it = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle))
       logInfo(s"Getting iterator from version $loadedVersion for prefix scan on " +
-        s"thread ID $tid")
+        s"thread ID ${key._1} and column family ID ${key._2}")
       it
     })
 
@@ -223,6 +283,14 @@ class RocksDB(
     }
   }
 
+  private def findColumnFamilyHandle(cf: String): ColumnFamilyHandle = {
+    val cfHandle = columnFamilyHandles.get(cf)
+    if (cfHandle == null) {
+      throw new IllegalArgumentException(s"Handle for column family $cf is not found")
+    }
+    cfHandle
+  }
+
   /**
    * Commit all the updates made as a version to DFS. The steps it needs to do to commits are:
    * - Write all the updates to the native RocksDB
@@ -242,11 +310,16 @@ class RocksDB(
       val writeTimeMs = timeTakenMs { db.write(writeOptions, writeBatch) }
 
       logInfo(s"Flushing updates for $newVersion")
-      val flushTimeMs = timeTakenMs { db.flush(flushOptions) }
+      val flushTimeMs = timeTakenMs {
+        db.flush(flushOptions,
+          new util.ArrayList[ColumnFamilyHandle](columnFamilyHandles.values()))
+      }
 
       val compactTimeMs = if (conf.compactOnCommit) {
         logInfo("Compacting")
-        timeTakenMs { db.compactRange() }
+        timeTakenMs {
+          columnFamilyHandles.values().forEach(cfHandle => db.compactRange(cfHandle))
+        }
       } else 0
 
       logInfo("Pausing background work")
@@ -279,6 +352,7 @@ class RocksDB(
       loadedVersion
     } catch {
       case t: Throwable =>
+        logWarning(s"ERROR! exc: $t", t)
         loadedVersion = -1  // invalidate loaded version
         throw t
     } finally {
@@ -422,12 +496,43 @@ class RocksDB(
 
   private def openDB(): Unit = {
     assert(db == null)
-    db = NativeRocksDB.open(dbOptions, workingDir.toString)
+
+    val columnFamilyDescriptors = new util.ArrayList[ColumnFamilyDescriptor]()
+    columnFamilies.foreach { cf =>
+      val bloomFilter = new BloomFilter()
+      val tableFormatConfig = new BlockBasedTableConfig()
+      tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024)
+      tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024))
+      tableFormatConfig.setFilterPolicy(bloomFilter)
+      tableFormatConfig.setFormatVersion(conf.formatVersion)
+
+      val columnFamilyOptions = new ColumnFamilyOptions()
+      columnFamilyOptions.setTableFormatConfig(tableFormatConfig)
+      columnFamilyDescriptors.add(new ColumnFamilyDescriptor(cf.getBytes(), columnFamilyOptions))
+    }
+
+    val cfHandles = new util.ArrayList[ColumnFamilyHandle](columnFamilyDescriptors.size())
+    db = NativeRocksDB.open(dbOptions, workingDir.toString, columnFamilyDescriptors,
+      cfHandles)
+
+    columnFamilyHandles = new util.HashMap[String, ColumnFamilyHandle]()
+    columnFamilies.indices.foreach { idx =>
+      columnFamilyHandles.put(columnFamilies(idx), cfHandles.get(idx))
+    }
+
+    // FIXME: constant
+    defaultColumnFamilyHandle = columnFamilyHandles.get("default")
+
     logInfo(s"Opened DB with conf ${conf}")
   }
 
   private def closeDB(): Unit = {
     if (db != null) {
+      columnFamilyHandles.entrySet().forEach(pair => db.destroyColumnFamilyHandle(pair.getValue))
+      columnFamilyHandles.clear()
+      columnFamilyHandles = null
+      defaultColumnFamilyHandle = null
+
       db.close()
       db = null
     }
@@ -441,10 +546,17 @@ class RocksDB(
         // Warn is mapped to info because RocksDB warn is too verbose
         // (e.g. dumps non-warning stuff like stats)
         val loggingFunc: ( => String) => Unit = infoLogLevel match {
+          /*
           case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_)
           case InfoLogLevel.WARN_LEVEL | InfoLogLevel.INFO_LEVEL => logInfo(_)
           case InfoLogLevel.DEBUG_LEVEL => logDebug(_)
           case _ => logTrace(_)
+           */
+          case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_)
+          case InfoLogLevel.WARN_LEVEL => logWarning(_)
+          case InfoLogLevel.INFO_LEVEL => logInfo(_)
+          case InfoLogLevel.DEBUG_LEVEL => logDebug(_)
+          case _ => logTrace(_)
         }
         loggingFunc(s"[NativeRocksDB-${infoLogLevel.getValue}] $logMsg")
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index 81755e5..323826d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -17,9 +17,12 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import java.lang.{Long => JLong}
+import java.nio.ByteOrder
+
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION}
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{StructField, StructType, TimestampType}
 import org.apache.spark.unsafe.Platform
 
 sealed trait RocksDBStateEncoder {
@@ -27,6 +30,11 @@ sealed trait RocksDBStateEncoder {
   def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
   def extractPrefixKey(key: UnsafeRow): UnsafeRow
 
+  def supportEventTimeIndex: Boolean
+  def extractEventTime(key: UnsafeRow): Long
+  def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte]
+  def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte])
+
   def encodeKey(row: UnsafeRow): Array[Byte]
   def encodeValue(row: UnsafeRow): Array[Byte]
 
@@ -39,11 +47,13 @@ object RocksDBStateEncoder {
   def getEncoder(
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int): RocksDBStateEncoder = {
+      numColsPrefixKey: Int,
+      eventTimeColIdx: Array[Int]): RocksDBStateEncoder = {
     if (numColsPrefixKey > 0) {
+      // FIXME: need to deal with prefix case as well
       new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
     } else {
-      new NoPrefixKeyStateEncoder(keySchema, valueSchema)
+      new NoPrefixKeyStateEncoder(keySchema, valueSchema, eventTimeColIdx)
     }
   }
 
@@ -86,6 +96,39 @@ object RocksDBStateEncoder {
   }
 }
 
+/**
+ * Helpers converting a long to and from a form whose byte sequence can be compared as-is
+ * (offset by offset) while still reflecting the numeric order of the original values.
+ */
+object BinarySortable {
+  // Mask having only the sign bit of a long set.
+  private val SIGN_BIT_LONG: Long = Long.MinValue
+  private val LITTLE_ENDIAN: Boolean = ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN
+
+  /** Encodes `value` so that the binary form of the result sorts in numeric order. */
+  def encodeToBinarySortableLong(value: Long): Long = {
+    // Toggling the sign bit places negative values before non-negative ones in a binary
+    // comparison, without disturbing the order inside each of the two groups.
+    swapBytesOnLittleEndian(value ^ SIGN_BIT_LONG)
+  }
+
+  /** Restores the value previously passed to [[encodeToBinarySortableLong]]. */
+  def decodeBinarySortableLong(encoded: Long): Long = {
+    // Bring the bytes back to the order of the system first, then toggle the sign bit back.
+    swapBytesOnLittleEndian(encoded) ^ SIGN_BIT_LONG
+  }
+
+  // The binary form is compared sequentially from the lowest offset, hence the bytes have to be
+  // laid out as BIG_ENDIAN would lay them out. The same swap converts in both directions.
+  private def swapBytesOnLittleEndian(bits: Long): Long = {
+    if (LITTLE_ENDIAN) JLong.reverseBytes(bits) else bits
+  }
+}
+
 class PrefixKeyScanStateEncoder(
     keySchema: StructType,
     valueSchema: StructType,
@@ -185,6 +228,23 @@ class PrefixKeyScanStateEncoder(
   }
 
   override def supportPrefixKeyScan: Boolean = true
+
+  // FIXME: the event time index is not implemented for the prefix key scan encoder yet.
+  override def supportEventTimeIndex: Boolean = false
+
+  /** Not supported by this encoder; always throws IllegalStateException. */
+  override def extractEventTime(key: UnsafeRow): Long = unsupportedEventTimeIndex()
+
+  /** Not supported by this encoder; always throws IllegalStateException. */
+  override def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = {
+    unsupportedEventTimeIndex()
+  }
+
+  /** Not supported by this encoder; always throws IllegalStateException. */
+  override def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = {
+    unsupportedEventTimeIndex()
+  }
+
+  // Single place raising the error shared by all event time index operations.
+  private def unsupportedEventTimeIndex(): Nothing = {
+    throw new IllegalStateException("This encoder doesn't support event time index!")
+  }
 }
 
 /**
@@ -197,8 +257,10 @@ class PrefixKeyScanStateEncoder(
  *    (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes,
  *    then the generated array byte will be N+1 bytes.
  */
-class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
-  extends RocksDBStateEncoder {
+class NoPrefixKeyStateEncoder(
+    keySchema: StructType,
+    valueSchema: StructType,
+    eventTimeColIdx: Array[Int]) extends RocksDBStateEncoder {
 
   import RocksDBStateEncoder._
 
@@ -207,6 +269,32 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
   private val valueRow = new UnsafeRow(valueSchema.size)
   private val rowTuple = new UnsafeRowPair()
 
+  validateColumnTypeOnEventTimeColumn()
+
+  /**
+   * Verifies that `eventTimeColIdx`, when non-empty, designates a timestamp column in
+   * `keySchema`: every index but the last has to select a struct-typed field to descend into,
+   * and the last index has to select a field of timestamp type.
+   *
+   * Throws IllegalStateException if an index is out of range for the schema it applies to, an
+   * intermediate field is not a struct, or the final field is not a timestamp.
+   */
+  private def validateColumnTypeOnEventTimeColumn(): Unit = {
+    // FIXME: better error message
+    def invalidEventTimeColumn(): Nothing = {
+      throw new IllegalStateException("event time column is not properly specified! " +
+        s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
+    }
+
+    // An out-of-range index is reported the same way as a type mismatch, rather than escaping
+    // as a bare index-out-of-bounds error which names neither the index path nor the schema.
+    def fieldAt(schema: StructType, idx: Int): StructField = {
+      if (idx >= 0 && idx < schema.length) schema(idx) else invalidEventTimeColumn()
+    }
+
+    if (eventTimeColIdx.nonEmpty) {
+      val leafSchema = eventTimeColIdx.init.foldLeft(keySchema) { (schema, idx) =>
+        fieldAt(schema, idx).dataType match {
+          case nested: StructType => nested
+          case _ => invalidEventTimeColumn()
+        }
+      }
+
+      fieldAt(leafSchema, eventTimeColIdx.last).dataType match {
+        case _: TimestampType =>
+        case _ => invalidEventTimeColumn()
+      }
+    }
+  }
+
   override def encodeKey(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
 
   override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
@@ -249,4 +337,41 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
   override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
     throw new IllegalStateException("This encoder doesn't support prefix key!")
   }
+
+  // The event time index is available only when an event time column has been specified.
+  override def supportEventTimeIndex: Boolean = eventTimeColIdx.nonEmpty
+
+  /**
+   * Returns the event time of `key`, located by following the column path `eventTimeColIdx`,
+   * divided by 1000 and rounded up (timestamp columns are presumably stored in microseconds,
+   * which makes the result milliseconds - confirm).
+   *
+   * Rounding up matters: eviction removes every row whose extracted time is not greater than
+   * the watermark, so rounding down would also evict rows whose event time is up to 999 units
+   * later than the watermark. Values which are multiples of 1000 are not affected.
+   */
+  override def extractEventTime(key: UnsafeRow): Long = {
+    var curRow: UnsafeRow = key
+    var curSchema: StructType = keySchema
+
+    // Descend through the nested struct columns; validation is done in initialization phase.
+    var depth = 0
+    val leafDepth = eventTimeColIdx.length - 1
+    while (depth < leafDepth) {
+      val idx = eventTimeColIdx(depth)
+      curSchema = curSchema(idx).dataType.asInstanceOf[StructType]
+      curRow = curRow.getStruct(idx, curSchema.length)
+      depth += 1
+    }
+
+    val eventTime = curRow.getLong(eventTimeColIdx.last)
+    // floorDiv/floorMod round consistently for negative values too, and cannot overflow.
+    val quotient = Math.floorDiv(eventTime, 1000L)
+    if (Math.floorMod(eventTime, 1000L) == 0L) quotient else quotient + 1
+  }
+
+  /**
+   * Builds the key of the event time index: 8 bytes holding the binary sortable form of
+   * `timestamp`, followed by the bytes of `encodedKey`.
+   */
+  override def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = {
+    val base = Platform.BYTE_ARRAY_OFFSET
+    val keyLen = encodedKey.length
+    val indexKey = new Array[Byte](8 + keyLen)
+    val sortableTs = BinarySortable.encodeToBinarySortableLong(timestamp)
+
+    // The timestamp goes first so that it dominates the byte-wise comparison of index keys.
+    Platform.putLong(indexKey, base, sortableTs)
+    Platform.copyMemory(encodedKey, base, indexKey, base + 8, keyLen)
+    indexKey
+  }
+
+  /**
+   * Splits a key of the event time index into the timestamp held in its first 8 bytes and a
+   * copy of the remaining bytes, which form the encoded state key.
+   */
+  override def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = {
+    val base = Platform.BYTE_ARRAY_OFFSET
+    val keyLen = eventTimeBytes.length - 8
+
+    val sortableTs = Platform.getLong(eventTimeBytes, base)
+    val stateKey = new Array[Byte](keyLen)
+    Platform.copyMemory(eventTimeBytes, base + 8, stateKey, base, keyLen)
+
+    (BinarySortable.decodeBinarySortableLong(sortableTs), stateKey)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index a2b33c2..1d66220 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -25,7 +25,8 @@ import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.util.{NextIterator, Utils}
 
 private[sql] class RocksDBStateStoreProvider
   extends StateStoreProvider with Logging with Closeable {
@@ -60,12 +61,23 @@ private[sql] class RocksDBStateStoreProvider
       verify(state == UPDATING, "Cannot put after already committed or aborted")
       verify(key != null, "Key cannot be null")
       require(value != null, "Cannot put a null value")
-      rocksDB.put(encoder.encodeKey(key), encoder.encodeValue(value))
+
+      val encodedKey = encoder.encodeKey(key)
+      val encodedValue = encoder.encodeValue(value)
+      rocksDB.put(encodedKey, encodedValue)
+
+      if (encoder.supportEventTimeIndex) {
+        val timestamp = encoder.extractEventTime(key)
+        val tsKey = encoder.encodeEventTimeIndexKey(timestamp, encodedKey)
+
+        rocksDB.put(CF_EVENT_TIME_INDEX, tsKey, Array.empty)
+      }
     }
 
     override def remove(key: UnsafeRow): Unit = {
       verify(state == UPDATING, "Cannot remove after already committed or aborted")
       verify(key != null, "Key cannot be null")
-      rocksDB.remove(encoder.encodeKey(key))
+
+      val encodedKey = encoder.encodeKey(key)
+      if (encoder.supportEventTimeIndex) {
+        // Remove the entry put() added to the event time index as well. A leftover entry would
+        // point to a row which no longer exists, and eviction treats that as a broken index.
+        val timestamp = encoder.extractEventTime(key)
+        val tsKey = encoder.encodeEventTimeIndexKey(timestamp, encodedKey)
+        rocksDB.remove(CF_EVENT_TIME_INDEX, tsKey)
+      }
+      rocksDB.remove(encodedKey)
     }
 
@@ -161,13 +173,75 @@ private[sql] class RocksDBStateStoreProvider
 
     /** Return the [[RocksDB]] instance in this store. This is exposed mainly for testing. */
     def dbInstance(): RocksDB = rocksDB
+
+    /**
+     * Removes the rows considered expired for the given watermark and returns them.
+     *
+     * When the encoder supports the event time index, the entries of the index column family
+     * are visited and each entry whose decoded timestamp is not greater than `watermarkMs` is
+     * removed from the index along with the row it points to. The first entry with a greater
+     * timestamp ends the iteration, and `altPred` is not consulted at all on this path.
+     *
+     * Otherwise every row of the store is visited and the rows for which `altPred` returns true
+     * are removed.
+     *
+     * Rows are removed while the returned iterator is being consumed. An index entry without a
+     * matching row raises IllegalStateException.
+     */
+    override def evictOnWatermark(
+        watermarkMs: Long,
+        altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+
+      if (encoder.supportEventTimeIndex) {
+        // Holder handed to the decoder for every evicted row.
+        val kv = new ByteArrayPair()
+
+        // FIXME: DEBUG
+        // logWarning(s"DEBUG: start iterating event time index, watermark $watermarkMs")
+
+        new NextIterator[UnsafeRowPair] {
+          private val iter = rocksDB.iterator(CF_EVENT_TIME_INDEX)
+          override protected def getNext(): UnsafeRowPair = {
+            if (iter.hasNext) {
+              val pair = iter.next()
+
+              // The index key starts with the binary sortable form of the event time.
+              val encodedTs = Platform.getLong(pair.key, Platform.BYTE_ARRAY_OFFSET)
+              val decodedTs = BinarySortable.decodeBinarySortableLong(encodedTs)
+
+              // FIXME: DEBUG
+              // logWarning(s"DEBUG: decoded TS: $decodedTs")
+
+              if (decodedTs > watermarkMs) {
+                // NOTE(review): stopping here presumably relies on the index being iterated in
+                // ascending order of timestamp - confirm against the iterator of RocksDB.
+                finished = true
+                null
+              } else {
+                // FIXME: can we leverage deleteRange to bulk delete on index?
+                rocksDB.remove(CF_EVENT_TIME_INDEX, pair.key)
+                val (_, encodedKey) = encoder.decodeEventTimeIndexKey(pair.key)
+                val value = rocksDB.get(encodedKey)
+                if (value == null) {
+                  throw new IllegalStateException("Event time index has been broken!")
+                }
+                kv.set(encodedKey, value)
+                // Decode before removing the row, as the decoded pair is what gets returned.
+                val rowPair = encoder.decode(kv)
+                rocksDB.remove(encodedKey)
+                rowPair
+              }
+            } else {
+              finished = true
+              null
+            }
+          }
+
+          override protected def close(): Unit = {
+            iter.closeIfNeeded()
+          }
+        }
+      } else {
+        // No index to rely on: visit all rows and let the predicate pick the rows to evict.
+        rocksDB.iterator().flatMap { kv =>
+          val rowPair = encoder.decode(kv)
+          if (altPred(rowPair)) {
+            rocksDB.remove(kv.key)
+            Some(rowPair)
+          } else {
+            None
+          }
+        }
+      }
+    }
   }
 
   override def init(
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): Unit = {
     this.stateStoreId_ = stateStoreId
@@ -176,11 +250,14 @@ private[sql] class RocksDBStateStoreProvider
     this.storeConf = storeConf
     this.hadoopConf = hadoopConf
 
-    require((keySchema.length == 0 && numColsPrefixKey == 0) ||
-      (keySchema.length > numColsPrefixKey), "The number of columns in the key must be " +
-      "greater than the number of columns for prefix key!")
+    require((keySchema.length == 0 && operatorContext.numColsPrefixKey == 0) ||
+      (keySchema.length > operatorContext.numColsPrefixKey), "The number of columns in the key " +
+      "must be greater than the number of columns for prefix key!")
 
-    this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema, numColsPrefixKey)
+    this.operatorContext = operatorContext
+
+    this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema,
+      operatorContext.numColsPrefixKey, operatorContext.eventTimeColIdx)
 
     rocksDB // lazy initialization
   }
@@ -212,6 +289,7 @@ private[sql] class RocksDBStateStoreProvider
   @volatile private var valueSchema: StructType = _
   @volatile private var storeConf: StateStoreConf = _
   @volatile private var hadoopConf: Configuration = _
+  @volatile private var operatorContext: StatefulOperatorContext = _
 
   private[sql] lazy val rocksDB = {
     val dfsRootDir = stateStoreId.storeCheckpointLocation().toString
@@ -219,7 +297,10 @@ private[sql] class RocksDBStateStoreProvider
       s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})"
     val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
     val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr)
-    new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr)
+    new RocksDB(dfsRootDir, RocksDBConf(storeConf),
+      columnFamilies = Seq("default", RocksDBStateStoreProvider.CF_EVENT_TIME_INDEX),
+      localRootDir = localRootDir,
+      hadoopConf = hadoopConf, loggingId = storeIdStr)
   }
 
   @volatile private var encoder: RocksDBStateEncoder = _
@@ -234,6 +315,9 @@ object RocksDBStateStoreProvider {
   val STATE_ENCODING_NUM_VERSION_BYTES = 1
   val STATE_ENCODING_VERSION: Byte = 0
 
+  // reserved column families
+  val CF_EVENT_TIME_INDEX: String = "__event_time_idx"
+
   // Native operation latencies report as latency in microseconds
   // as SQLMetrics support millis. Convert the value to millis
   val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 5020638..cee9ad1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -130,6 +130,11 @@ trait StateStore extends ReadStateStore {
    */
   override def iterator(): Iterator[UnsafeRowPair]
 
+  /**
+   * Remove the rows regarded as expired for the given watermark, and return an iterator over
+   * the removed rows.
+   *
+   * @param watermarkMs watermark to evict against; the RocksDB state store compares it with the
+   *                    timestamps kept in its event time index
+   * @param altPred     predicate selecting the rows to evict; the RocksDB state store consults
+   *                    it only when its encoder has no event time index
+   */
+  def evictOnWatermark(
+      watermarkMs: Long,
+      altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair]
+
   /** Current metrics of the state store */
   def metrics: StateStoreMetrics
 
@@ -229,6 +234,19 @@ class InvalidUnsafeRowException
     "checkpoint or use the legacy Spark version to process the streaming state.", null)
 
 /**
+ * Settings of a stateful operator which are handed over to the state store provider when the
+ * provider is initialized.
+ *
+ * NOTE(review): `eventTimeColIdx` is an Array, hence equals/hashCode generated for this case
+ * class compare that field by reference rather than by content.
+ *
+ * @param numColsPrefixKey The number of leftmost columns to be used as prefix key.
+ *                         A value not greater than 0 means the operator doesn't activate prefix
+ *                         key, and the operator should not call prefixScan method in StateStore.
+ * @param eventTimeColIdx  Ordinals locating the column which holds the event time of the row.
+ *                         It is an array because the column can be nested in struct type
+ *                         columns: each ordinal but the last selects a struct column to descend
+ *                         into. Only works when the column is in the key. Empty (the default)
+ *                         means no event time column is specified.
+ */
+case class StatefulOperatorContext(
+    numColsPrefixKey: Int = 0,
+    eventTimeColIdx: Array[Int] = Array.empty)
+
+/**
  * Trait representing a provider that provide [[StateStore]] instances representing
  * versions of state data.
  *
@@ -255,9 +273,7 @@ trait StateStoreProvider {
    * @param stateStoreId Id of the versioned StateStores that this provider will generate
    * @param keySchema Schema of keys to be stored
    * @param valueSchema Schema of value to be stored
-   * @param numColsPrefixKey The number of leftmost columns to be used as prefix key.
-   *                         A value not greater than 0 means the operator doesn't activate prefix
-   *                         key, and the operator should not call prefixScan method in StateStore.
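+   * @param operatorContext Settings of the stateful operator using the store, e.g. the number of
+   *                        prefix key columns and the location of the event time column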
    * @param storeConfs Configurations used by the StateStores
    * @param hadoopConf Hadoop configuration that could be used by StateStore to save state data
    */
@@ -265,7 +281,7 @@ trait StateStoreProvider {
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       storeConfs: StateStoreConf,
       hadoopConf: Configuration): Unit
 
@@ -318,11 +334,11 @@ object StateStoreProvider {
       providerId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStoreProvider = {
     val provider = create(storeConf.providerClass)
-    provider.init(providerId.storeId, keySchema, valueSchema, numColsPrefixKey,
+    provider.init(providerId.storeId, keySchema, valueSchema, operatorContext,
       storeConf, hadoopConf)
     provider
   }
@@ -471,13 +487,13 @@ object StateStore extends Logging {
       storeProviderId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       version: Long,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): ReadStateStore = {
     require(version >= 0)
     val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
-      numColsPrefixKey, storeConf, hadoopConf)
+      operatorContext, storeConf, hadoopConf)
     storeProvider.getReadStore(version)
   }
 
@@ -486,13 +502,13 @@ object StateStore extends Logging {
       storeProviderId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       version: Long,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStore = {
     require(version >= 0)
     val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
-      numColsPrefixKey, storeConf, hadoopConf)
+      operatorContext, storeConf, hadoopConf)
     storeProvider.getStore(version)
   }
 
@@ -500,7 +516,7 @@ object StateStore extends Logging {
       storeProviderId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStoreProvider = {
     loadedProviders.synchronized {
@@ -527,7 +543,7 @@ object StateStore extends Logging {
       val provider = loadedProviders.getOrElseUpdate(
         storeProviderId,
         StateStoreProvider.createAndInit(
-          storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeConf, hadoopConf)
+          storeProviderId, keySchema, valueSchema, operatorContext, storeConf, hadoopConf)
       )
       val otherProviderIds = loadedProviders.keys.filter(_ != storeProviderId).toSeq
       val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, otherProviderIds)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index fbe83ad..33d83b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -74,7 +74,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
     storeVersion: Long,
     keySchema: StructType,
     valueSchema: StructType,
-    numColsPrefixKey: Int,
+    operatorContext: StatefulOperatorContext,
     sessionState: SessionState,
     @transient private val storeCoordinator: Option[StateStoreCoordinatorRef],
     extraOptions: Map[String, String] = Map.empty)
@@ -87,7 +87,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
     val storeProviderId = getStateProviderId(partition)
 
     val store = StateStore.getReadOnly(
-      storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeVersion,
+      storeProviderId, keySchema, valueSchema, operatorContext, storeVersion,
       storeConf, hadoopConfBroadcast.value.value)
     val inputIter = dataRDD.iterator(partition, ctxt)
     storeReadFunction(store, inputIter)
@@ -108,7 +108,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     storeVersion: Long,
     keySchema: StructType,
     valueSchema: StructType,
-    numColsPrefixKey: Int,
+    operatorContext: StatefulOperatorContext,
     sessionState: SessionState,
     @transient private val storeCoordinator: Option[StateStoreCoordinatorRef],
     extraOptions: Map[String, String] = Map.empty)
@@ -121,7 +121,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     val storeProviderId = getStateProviderId(partition)
 
     val store = StateStore.get(
-      storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeVersion,
+      storeProviderId, keySchema, valueSchema, operatorContext, storeVersion,
       storeConf, hadoopConfBroadcast.value.value)
     val inputIter = dataRDD.iterator(partition, ctxt)
     storeUpdateFunction(store, inputIter)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
index 36138f1..c6b63cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
@@ -51,6 +51,12 @@ sealed trait StreamingAggregationStateManager extends Serializable {
   /** Remove a single non-null key from the target state store. */
   def remove(store: StateStore, key: UnsafeRow): Unit
 
+  /**
+   * Evict rows from the target state store for the given watermark, and return an iterator
+   * containing the evicted key-value pairs. `watermarkMs` and `altPred` are forwarded to
+   * `evictOnWatermark` of the state store by the implementations in this file.
+   */
+  def evictOnWatermark(
+    store: StateStore,
+    watermarkMs: Long,
+    altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair]
+
   /** Return an iterator containing all the key-value pairs in target state store. */
   def iterator(store: ReadStateStore): Iterator[UnsafeRowPair]
 
@@ -128,6 +134,13 @@ class StreamingAggregationStateManagerImplV1(
   override def values(store: ReadStateStore): Iterator[UnsafeRow] = {
     store.iterator().map(_.value)
   }
+
+  /** Delegates the eviction to the state store and returns the evicted pairs unchanged. */
+  override def evictOnWatermark(
+      store: StateStore,
+      watermarkMs: Long,
+      altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+    store.evictOnWatermark(watermarkMs, altPred)
+  }
 }
 
 /**
@@ -186,6 +199,16 @@ class StreamingAggregationStateManagerImplV2(
     store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair)))
   }
 
+
+  /**
+   * Delegates the eviction to the state store, and replaces the value of each evicted pair
+   * with the row given by restoreOriginalRow, the same way iterator() does.
+   */
+  override def evictOnWatermark(
+      store: StateStore,
+      watermarkMs: Long,
+      altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+    store.evictOnWatermark(watermarkMs, altPred).map { rowPair =>
+      new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair))
+    }
+  }
+
   override def values(store: ReadStateStore): Iterator[UnsafeRow] = {
     store.iterator().map(rowPair => restoreOriginalRow(rowPair))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
index f301d23..ce56845 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -363,8 +363,9 @@ class SymmetricHashJoinStateManager(
     protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = {
       val storeProviderId = StateStoreProviderId(
         stateInfo.get, partitionId, getStateStoreName(joinSide, stateStoreType))
+      // FIXME: would setting prefixScan / evict help?
       val store = StateStore.get(
-        storeProviderId, keySchema, valueSchema, numColsPrefixKey = 0,
+        storeProviderId, keySchema, valueSchema, StatefulOperatorContext(),
         stateInfo.get.storeVersion, storeConf, hadoopConf)
       logInfo(s"Loaded store ${store.id}")
       store
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 01ff72b..7cd25eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -35,14 +35,14 @@ package object state {
         stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,
-        numColsPrefixKey: Int)(
+        operatorContext: StatefulOperatorContext)(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
 
       mapPartitionsWithStateStore(
         stateInfo,
         keySchema,
         valueSchema,
-        numColsPrefixKey,
+        operatorContext,
         sqlContext.sessionState,
         Some(sqlContext.streams.stateStoreCoordinator))(
         storeUpdateFunction)
@@ -53,7 +53,7 @@ package object state {
         stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,
-        numColsPrefixKey: Int,
+        operatorContext: StatefulOperatorContext,
         sessionState: SessionState,
         storeCoordinator: Option[StateStoreCoordinatorRef],
         extraOptions: Map[String, String] = Map.empty)(
@@ -77,7 +77,7 @@ package object state {
         stateInfo.storeVersion,
         keySchema,
         valueSchema,
-        numColsPrefixKey,
+        operatorContext,
         sessionState,
         storeCoordinator,
         extraOptions)
@@ -88,7 +88,7 @@ package object state {
         stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,
-        numColsPrefixKey: Int,
+        operatorContext: StatefulOperatorContext,
         sessionState: SessionState,
         storeCoordinator: Option[StateStoreCoordinatorRef],
         extraOptions: Map[String, String] = Map.empty)(
@@ -112,7 +112,7 @@ package object state {
         stateInfo.storeVersion,
         keySchema,
         valueSchema,
-        numColsPrefixKey,
+        operatorContext,
         sessionState,
         storeCoordinator,
         extraOptions)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 3431823..923d94e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -237,11 +237,10 @@ trait WatermarkSupport extends SparkPlan {
   protected def removeKeysOlderThanWatermark(store: StateStore): Unit = {
     if (watermarkPredicateForKeys.nonEmpty) {
       val numRemovedStateRows = longMetric("numRemovedStateRows")
-      store.iterator().foreach { rowPair =>
-        if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
-          store.remove(rowPair.key)
-          numRemovedStateRows += 1
-        }
+      // The store removes the rows by itself; both the watermark and the predicate on keys are
+      // given so that the store can choose which one to evict by. Every row coming out of the
+      // returned iterator has been evicted, hence it is counted as a removed state row.
+      store.evictOnWatermark(eventTimeWatermark.get, pair => {
+        watermarkPredicateForKeys.get.eval(pair.key)
+      }).foreach { _ =>
+        numRemovedStateRows += 1
       }
     }
   }
@@ -251,11 +250,10 @@ trait WatermarkSupport extends SparkPlan {
       store: StateStore): Unit = {
     if (watermarkPredicateForKeys.nonEmpty) {
       val numRemovedStateRows = longMetric("numRemovedStateRows")
-      storeManager.keys(store).foreach { keyRow =>
-        if (watermarkPredicateForKeys.get.eval(keyRow)) {
-          storeManager.remove(store, keyRow)
-          numRemovedStateRows += 1
-        }
+      storeManager.evictOnWatermark(store,
+        eventTimeWatermark.get, pair => watermarkPredicateForKeys.get.eval(pair.key)
+      ).foreach { _ =>
+        numRemovedStateRows += 1
       }
     }
   }
@@ -307,7 +305,8 @@ case class StateStoreRestoreExec(
       getStateInfo,
       keyExpressions.toStructType,
       stateManager.getStateValueSchema,
-      numColsPrefixKey = 0,
+      // FIXME: set event time column here!
+      StatefulOperatorContext(),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
         val hasInput = iter.hasNext
@@ -365,11 +364,26 @@ case class StateStoreSaveExec(
     assert(outputMode.nonEmpty,
       "Incorrect planning in IncrementalExecution, outputMode has not been set")
 
+    val eventTimeIdx = keyExpressions.indexWhere(_.metadata.contains(EventTimeWatermark.delayKey))
+    val eventTimeColIdx = if (eventTimeIdx >= 0) {
+      keyExpressions.toStructType(eventTimeIdx).dataType match {
+        // FIXME: for now, we only consider window operation here, as we do the same in
+        //   WatermarkSupport.watermarkExpression
+        case StructType(_) => Array[Int](eventTimeIdx, 1)
+        case TimestampType => Array[Int](eventTimeIdx)
+        case _ => throw new IllegalStateException(
+          "The type of event time column should be timestamp")
+      }
+    } else {
+      Array.empty[Int]
+    }
+
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
       keyExpressions.toStructType,
       stateManager.getStateValueSchema,
-      numColsPrefixKey = 0,
+      // Hand over the location of the event time column (if any) to the state store.
+      StatefulOperatorContext(eventTimeColIdx = eventTimeColIdx),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { (store, iter) =>
         val numOutputRows = longMetric("numOutputRows")
@@ -414,18 +428,16 @@ case class StateStoreSaveExec(
             }
 
             val removalStartTimeNs = System.nanoTime
-            val rangeIter = stateManager.iterator(store)
+            val evictedIter = stateManager.evictOnWatermark(store,
+              eventTimeWatermark.get, pair => watermarkPredicateForKeys.get.eval(pair.key))
 
             new NextIterator[InternalRow] {
               override protected def getNext(): InternalRow = {
                 var removedValueRow: InternalRow = null
-                while(rangeIter.hasNext && removedValueRow == null) {
-                  val rowPair = rangeIter.next()
-                  if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
-                    stateManager.remove(store, rowPair.key)
-                    numRemovedStateRows += 1
-                    removedValueRow = rowPair.value
-                  }
+                while(evictedIter.hasNext && removedValueRow == null) {
+                  val rowPair = evictedIter.next()
+                  numRemovedStateRows += 1
+                  removedValueRow = rowPair.value
                 }
                 if (removedValueRow == null) {
                   finished = true
@@ -541,7 +553,8 @@ case class SessionWindowStateStoreRestoreExec(
       getStateInfo,
       stateManager.getStateKeySchema,
       stateManager.getStateValueSchema,
-      numColsPrefixKey = stateManager.getNumColsForPrefixKey,
+      // FIXME: set event time column here!
+      StatefulOperatorContext(stateManager.getNumColsForPrefixKey),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
 
@@ -618,7 +631,8 @@ case class SessionWindowStateStoreSaveExec(
       getStateInfo,
       stateManager.getStateKeySchema,
       stateManager.getStateValueSchema,
-      numColsPrefixKey = stateManager.getNumColsForPrefixKey,
+      // FIXME: set event time column!
+      StatefulOperatorContext(numColsPrefixKey = stateManager.getNumColsForPrefixKey),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
 
@@ -652,6 +666,7 @@ case class SessionWindowStateStoreSaveExec(
 
           val removalStartTimeNs = System.nanoTime
           new NextIterator[InternalRow] {
+            // FIXME: can we optimize this case as well?
             private val removedIter = stateManager.removeByValueCondition(
               store, watermarkPredicateForData.get.eval)
 
@@ -751,7 +766,8 @@ case class StreamingDeduplicateExec(
       getStateInfo,
       keyExpressions.toStructType,
       child.output.toStructType,
-      numColsPrefixKey = 0,
+      // FIXME: set event time column!
+      StatefulOperatorContext(),
       session.sessionState,
       Some(session.streams.stateStoreCoordinator),
       // We won't check value row in state store since the value StreamingDeduplicateExec.EMPTY_ROW
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
index 8bba9b8..ba15f50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, SortOrder, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{LimitExec, SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.streaming.state.StateStoreOps
+import org.apache.spark.sql.execution.streaming.state.{StatefulOperatorContext, StateStoreOps}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType}
 import org.apache.spark.util.{CompletionIterator, NextIterator}
@@ -52,7 +52,8 @@ case class StreamingGlobalLimitExec(
         getStateInfo,
         keySchema,
         valueSchema,
-        numColsPrefixKey = 0,
+        // FIXME: set event time column!
+        StatefulOperatorContext(),
         session.sessionState,
         Some(session.streams.stateStoreCoordinator)) { (store, iter) =>
       val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
new file mode 100644
index 0000000..9a83f5c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
@@ -0,0 +1,633 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import java.{util => jutil}
+
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.rocksdb.RocksDBException
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StatefulOperatorContext, StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType, TimestampType}
+import org.apache.spark.util.Utils
+
+/**
+ * Synthetic benchmark for State Store operations.
+ * To run this benchmark:
+ * {{{
+ *   1. without sbt:
+ *      bin/spark-submit --class <this class>
+ *        --jars <spark core test jar>,<spark catalyst test jar> <sql core test jar>
+ *   2. build/sbt "sql/test:runMain <this class>"
+ *   3. generate result:
+ *      SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ *      Results will be written to "benchmarks/StateStoreBenchmark-results.txt".
+ * }}}
+ */
+object StateStoreBenchmark extends SqlBasedBenchmark {
+
+  private val numOfRows: Seq[Int] = Seq(10000, 50000, 100000) // Seq(10000, 100000, 1000000)
+
+  // 200%, 100%, 50%, 25%, 10%, 5%, 1%, no update
+  // rate is relative to the number of rows in prev. batch
+  private val updateRates: Seq[Int] = Seq(25, 10, 5) // Seq(200, 100, 50, 25, 10, 5, 1, 0)
+
+  // 100%, 50%, 25%, 10%, 5%, 1%, no evict
+  // rate is relative to the number of rows in prev. batch
+  private val evictRates: Seq[Int] = Seq(100, 50, 25, 10, 5, 1, 0)
+
+  private val keySchema = StructType(
+    Seq(StructField("key1", IntegerType, true), StructField("key2", TimestampType, true)))
+  private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+
+  private val keyProjection = UnsafeProjection.create(keySchema)
+  private val valueProjection = UnsafeProjection.create(valueSchema)
+
+  private def runEvictBenchmark(): Unit = {
+    runBenchmark("evict rows") {
+      val numOfRows = Seq(10000) // Seq(1000, 10000, 100000)
+      val numOfTimestamps = Seq(10, 100, 1000)
+      val numOfEvictionRates = Seq(50, 25, 10, 5, 1, 0) // Seq(100, 75, 50, 25, 1, 0)
+
+      numOfRows.foreach { numOfRow =>
+        numOfTimestamps.foreach { numOfTimestamp =>
+          val timestampsInMicros = (0L until numOfTimestamp).map(ts => ts * 1000L).toList
+
+          val testData = constructRandomizedTestData(numOfRow, timestampsInMicros, 0)
+
+          val rocksDBProvider = newRocksDBStateProviderWithEventTimeIdx()
+          val rocksDBStore = rocksDBProvider.getStore(0)
+          updateRows(rocksDBStore, testData)
+
+          val committedVersion = try {
+            rocksDBStore.commit()
+          } catch {
+            case exc: RocksDBException =>
+              // scalastyle:off println
+              System.out.println(s"Exception in RocksDB happen! ${exc.getMessage} / " +
+                s"status: ${exc.getStatus.getState} / ${exc.getStatus.getCodeString}" )
+              exc.printStackTrace()
+              throw exc
+              // scalastyle:on println
+          }
+
+          numOfEvictionRates.foreach { numOfEvictionRate =>
+            val numOfRowsToEvict = numOfRow * numOfEvictionRate / 100
+            // scalastyle:off println
+            System.out.println(s"numOfRowsToEvict: $numOfRowsToEvict / " +
+              s"timestampsInMicros: $timestampsInMicros / " +
+              s"numOfEvictionRate: $numOfEvictionRate / " +
+              s"numOfTimestamp: $numOfTimestamp / " +
+              s"take: ${numOfTimestamp * numOfEvictionRate / 100}")
+
+            // scalastyle:on println
+            val maxTimestampToEvictInMillis = timestampsInMicros
+              .take(numOfTimestamp * numOfEvictionRate / 100)
+              .lastOption.map(_ / 1000).getOrElse(-1L)
+
+            val benchmark = new Benchmark(s"evicting $numOfRowsToEvict rows " +
+              s"(max timestamp to evict in millis: $maxTimestampToEvictInMillis) " +
+              s"from $numOfRow rows with $numOfTimestamp timestamps " +
+              s"(${numOfRow / numOfTimestamp} rows" +
+              s" for the same timestamp)",
+              numOfRow, minNumIters = 1000, output = output)
+
+            benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+              val rocksDBStore = rocksDBProvider.getStore(committedVersion)
+
+              timer.startTiming()
+              evictAsFullScanAndRemove(rocksDBStore, maxTimestampToEvictInMillis)
+              timer.stopTiming()
+
+              rocksDBStore.abort()
+            }
+
+            benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer =>
+              val rocksDBStore = rocksDBProvider.getStore(committedVersion)
+
+              timer.startTiming()
+              evictAsNewEvictApi(rocksDBStore, maxTimestampToEvictInMillis)
+              timer.stopTiming()
+
+              rocksDBStore.abort()
+            }
+
+            benchmark.run()
+          }
+
+          rocksDBProvider.close()
+        }
+      }
+    }
+  }
+
+  private def runPutBenchmark(): Unit = {
+    runBenchmark("put rows") {
+      val numOfRows = Seq(10000) // Seq(1000, 10000, 100000)
+      val numOfTimestamps = Seq(100, 1000, 10000) // Seq(1, 10, 100, 1000, 10000)
+      numOfRows.foreach { numOfRow =>
+        numOfTimestamps.foreach { numOfTimestamp =>
+          val timestamps = (0L until numOfTimestamp).map(ts => ts * 1000L).toList
+
+          val testData = constructRandomizedTestData(numOfRow, timestamps, 0)
+
+          val rocksDBProvider = newRocksDBStateProvider()
+          val rocksDBWithIdxProvider = newRocksDBStateProviderWithEventTimeIdx()
+
+          val benchmark = new Benchmark(s"putting $numOfRow rows, with $numOfTimestamp " +
+            s"timestamps (${numOfRow / numOfTimestamp} rows for the same timestamp)",
+            numOfRow, minNumIters = 1000, output = output)
+
+          benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+            val rocksDBStore = rocksDBProvider.getStore(0)
+
+            timer.startTiming()
+            updateRows(rocksDBStore, testData)
+            timer.stopTiming()
+
+            rocksDBStore.abort()
+          }
+
+          benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer =>
+            val rocksDBWithIdxStore = rocksDBWithIdxProvider.getStore(0)
+
+            timer.startTiming()
+            updateRows(rocksDBWithIdxStore, testData)
+            timer.stopTiming()
+
+            rocksDBWithIdxStore.abort()
+          }
+
+          benchmark.run()
+
+          rocksDBProvider.close()
+          rocksDBWithIdxProvider.close()
+        }
+      }
+    }
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    // runPutBenchmark()
+    runEvictBenchmark()
+
+    /*
+    val testData = constructRandomizedTestData(numOfRows.max)
+
+    skip("scanning and comparing") {
+      numOfRows.foreach { numOfRow =>
+        val curData = testData.take(numOfRow)
+
+        val inMemoryProvider = newHDFSBackedStateStoreProvider()
+        val inMemoryStore = inMemoryProvider.getStore(0)
+
+        val rocksDBProvider = newRocksDBStateProvider()
+        val rocksDBStore = rocksDBProvider.getStore(0)
+
+        updateRows(inMemoryStore, curData)
+        updateRows(rocksDBStore, curData)
+
+        val newVersionForInMemory = inMemoryStore.commit()
+        val newVersionForRocksDB = rocksDBStore.commit()
+
+        val benchmark = new Benchmark(s"scanning and comparing $numOfRow rows",
+          numOfRow, minNumIters = 1000, output = output)
+
+        benchmark.addTimerCase("HDFSBackedStateStoreProvider") { timer =>
+          val inMemoryStore2 = inMemoryProvider.getStore(newVersionForInMemory)
+
+          timer.startTiming()
+          // NOTE: the latency would be quite similar regardless of the rate of eviction
+          // as we don't remove the actual row, so I simply picked 10 %
+          fullScanAndCompareTimestamp(inMemoryStore2, (numOfRow * 0.1).toInt)
+          timer.stopTiming()
+
+          inMemoryStore2.abort()
+        }
+
+        benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+          val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB)
+
+          timer.startTiming()
+          // NOTE: the latency would be quite similar regardless of the rate of eviction
+          // as we don't remove the actual row, so I simply picked 10 %
+          fullScanAndCompareTimestamp(rocksDBStore2, (numOfRow * 0.1).toInt)
+          timer.stopTiming()
+
+          rocksDBStore2.abort()
+        }
+
+        benchmark.run()
+
+        inMemoryProvider.close()
+        rocksDBProvider.close()
+      }
+    }
+
+    // runBenchmark("simulate full operations on eviction") {
+    skip("simulate full operations on eviction") {
+      numOfRows.foreach { numOfRow =>
+        val curData = testData.take(numOfRow)
+
+        val inMemoryProvider = newHDFSBackedStateStoreProvider()
+        val inMemoryStore = inMemoryProvider.getStore(0)
+
+        val rocksDBProvider = newRocksDBStateProvider()
+        val rocksDBStore = rocksDBProvider.getStore(0)
+
+        val indexForInMemoryStore = new jutil.concurrent.ConcurrentSkipListMap[
+          Int, jutil.List[UnsafeRow]]()
+        val indexForRocksDBStore = new jutil.concurrent.ConcurrentSkipListMap[
+          Int, jutil.List[UnsafeRow]]()
+
+        updateRowsWithSortedMapIndex(inMemoryStore, indexForInMemoryStore, curData)
+        updateRowsWithSortedMapIndex(rocksDBStore, indexForRocksDBStore, curData)
+
+        assert(indexForInMemoryStore.size() == numOfRow)
+        assert(indexForRocksDBStore.size() == numOfRow)
+
+        val newVersionForInMemory = inMemoryStore.commit()
+        val newVersionForRocksDB = rocksDBStore.commit()
+
+        val rowsToUpdate = constructRandomizedTestData(numOfRow / 100 * updateRates.max,
+          minIdx = numOfRow + 1)
+
+        updateRates.foreach { updateRate =>
+          val numRowsUpdate = numOfRow / 100 * updateRate
+          val curRowsToUpdate = rowsToUpdate.take(numRowsUpdate)
+
+          evictRates.foreach { evictRate =>
+            val maxIdxToEvict = numOfRow / 100 * evictRate
+
+            val benchmark = new Benchmark(s"simulating evict on $numOfRow rows, update " +
+              s"$numRowsUpdate rows ($updateRate %), evict $maxIdxToEvict rows ($evictRate %)",
+              numOfRow, minNumIters = 100, output = output)
+
+            benchmark.addTimerCase("HDFSBackedStateStoreProvider") { timer =>
+              val inMemoryStore2 = inMemoryProvider.getStore(newVersionForInMemory)
+
+              timer.startTiming()
+              updateRows(inMemoryStore2, curRowsToUpdate)
+              evictAsFullScanAndRemove(inMemoryStore2, maxIdxToEvict)
+              timer.stopTiming()
+
+              inMemoryStore2.abort()
+            }
+
+            benchmark.addTimerCase("HDFSBackedStateStoreProvider - sorted map index") { timer =>
+
+              val inMemoryStore2 = inMemoryProvider.getStore(newVersionForInMemory)
+
+              val curIndex = new jutil.concurrent.ConcurrentSkipListMap[Int,
+                jutil.List[UnsafeRow]]()
+              curIndex.putAll(indexForInMemoryStore)
+
+              assert(curIndex.size() == numOfRow)
+
+              timer.startTiming()
+              updateRowsWithSortedMapIndex(inMemoryStore2, curIndex, curRowsToUpdate)
+
+              assert(curIndex.size() == numOfRow + curRowsToUpdate.size)
+
+              evictAsScanSortedMapIndexAndRemove(inMemoryStore2, curIndex, maxIdxToEvict)
+              timer.stopTiming()
+
+              assert(curIndex.size() == numOfRow + curRowsToUpdate.size - maxIdxToEvict)
+
+              curIndex.clear()
+
+              inMemoryStore2.abort()
+            }
+
+            benchmark.run()
+
+            val benchmark2 = new Benchmark(s"simulating evict on $numOfRow rows, update " +
+              s"$numRowsUpdate rows ($updateRate %), evict $maxIdxToEvict rows ($evictRate %)",
+              numOfRow, minNumIters = 100, output = output)
+
+            benchmark2.addTimerCase("RocksDBStateStoreProvider") { timer =>
+              val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB)
+
+              timer.startTiming()
+              updateRows(rocksDBStore2, curRowsToUpdate)
+              evictAsFullScanAndRemove(rocksDBStore2, maxIdxToEvict)
+              timer.stopTiming()
+
+              rocksDBStore2.abort()
+            }
+
+            benchmark2.addTimerCase("RocksDBStateStoreProvider - sorted map index") { timer =>
+
+              val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB)
+
+              val curIndex = new jutil.concurrent.ConcurrentSkipListMap[Int,
+                jutil.List[UnsafeRow]]()
+              curIndex.putAll(indexForRocksDBStore)
+
+              assert(curIndex.size() == numOfRow)
+
+              timer.startTiming()
+              updateRowsWithSortedMapIndex(rocksDBStore2, curIndex, curRowsToUpdate)
+
+              assert(curIndex.size() == numOfRow + curRowsToUpdate.size)
+
+              evictAsScanSortedMapIndexAndRemove(rocksDBStore2, curIndex, maxIdxToEvict)
+              timer.stopTiming()
+
+              assert(curIndex.size() == numOfRow + curRowsToUpdate.size - maxIdxToEvict)
+
+              curIndex.clear()
+
+              rocksDBStore2.abort()
+            }
+
+            benchmark2.run()
+          }
+        }
+
+        inMemoryProvider.close()
+        rocksDBProvider.close()
+      }
+    }
+
+    // runBenchmark("simulate full operations on eviction") {
+    skip("simulate full operations on eviction - scannable index") {
+      numOfRows.foreach { numOfRow =>
+        val curData = testData.take(numOfRow)
+
+        val rocksDBProvider = newRocksDBStateProvider()
+        val rocksDBStore = rocksDBProvider.getStore(0)
+
+        val rocksDBWithIdxProvider = newRocksDBStateProviderWithEventTimeIdx()
+        val rocksDBWithIdxStore = rocksDBWithIdxProvider.getStore(0)
+
+        updateRows(rocksDBStore, curData)
+        updateRows(rocksDBWithIdxStore, curData)
+
+        val newVersionForRocksDB = rocksDBStore.commit()
+        val newVersionForRocksDBWithIdx = rocksDBWithIdxStore.commit()
+
+        val rowsToUpdate = constructRandomizedTestData(numOfRow / 100 * updateRates.max,
+          minIdx = numOfRow + 1)
+
+        updateRates.foreach { updateRate =>
+          val numRowsUpdate = numOfRow / 100 * updateRate
+          val curRowsToUpdate = rowsToUpdate.take(numRowsUpdate)
+
+          evictRates.foreach { evictRate =>
+            val maxIdxToEvict = numOfRow / 100 * evictRate
+
+            val benchmark = new Benchmark(s"simulating evict on $numOfRow rows, update " +
+              s"$numRowsUpdate rows ($updateRate %), evict $maxIdxToEvict rows ($evictRate %)",
+              numOfRow, minNumIters = 100, output = output)
+
+            benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+              val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB)
+
+              timer.startTiming()
+              updateRows(rocksDBStore2, curRowsToUpdate)
+              evictAsFullScanAndRemove(rocksDBStore2, maxIdxToEvict)
+              // evictAsNewEvictApi(rocksDBStore2, maxIdxToEvict)
+              timer.stopTiming()
+
+              rocksDBStore2.abort()
+            }
+
+            benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer =>
+              val rocksDBWithIdxStore2 = rocksDBWithIdxProvider.getStore(
+                newVersionForRocksDBWithIdx)
+
+              timer.startTiming()
+              updateRows(rocksDBWithIdxStore2, curRowsToUpdate)
+              evictAsNewEvictApi(rocksDBWithIdxStore2, maxIdxToEvict)
+              timer.stopTiming()
+
+              rocksDBWithIdxStore2.abort()
+            }
+
+            benchmark.run()
+          }
+        }
+
+        rocksDBProvider.close()
+        rocksDBWithIdxProvider.close()
+      }
+    }
+
+    runBenchmark("put rows") {
+      numOfRows.foreach { numOfRow =>
+        val curData = testData.take(numOfRow)
+
+        val rocksDBProvider = newRocksDBStateProvider()
+        val rocksDBWithIdxProvider = newRocksDBStateProviderWithEventTimeIdx()
+
+        val benchmark = new Benchmark(s"putting $numOfRow rows",
+          numOfRow, minNumIters = 1000, output = output)
+
+        benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+          val rocksDBStore = rocksDBProvider.getStore(0)
+
+          timer.startTiming()
+          updateRows(rocksDBStore, curData)
+          timer.stopTiming()
+
+          rocksDBStore.abort()
+        }
+
+        benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer =>
+          val rocksDBWithIdxStore = rocksDBWithIdxProvider.getStore(0)
+
+          timer.startTiming()
+          updateRows(rocksDBWithIdxStore, curData)
+          timer.stopTiming()
+
+          rocksDBWithIdxStore.abort()
+        }
+
+        benchmark.run()
+
+        rocksDBProvider.close()
+        rocksDBWithIdxProvider.close()
+      }
+    }
+     */
+  }
+
+  final def skip(benchmarkName: String)(func: => Any): Unit = {
+    output.foreach(_.write(s"$benchmarkName is skipped".getBytes))
+  }
+
+  private def updateRows(
+      store: StateStore,
+      rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = {
+    rows.foreach { case (key, value) =>
+      store.put(key, value)
+    }
+  }
+
+  // Evicts by full scan; key column 1 holds micros while `maxTimestampToEvict` is in millis.
+  private def evictAsFullScanAndRemove(store: StateStore, maxTimestampToEvict: Long): Unit = {
+    val maxTimestampToEvictInMicros = maxTimestampToEvict * 1000L // inclusive upper bound
+    store.iterator().foreach { r =>
+      if (r.key.getLong(1) <= maxTimestampToEvictInMicros) {
+        store.remove(r.key)
+      }
+    }
+  }
+
+  // Evicts via `evictOnWatermark`; the predicate compares in the key's unit (micros), inclusive.
+  private def evictAsNewEvictApi(store: StateStore, maxTimestampToEvict: Long): Unit = {
+    val maxTimestampToEvictInMicros = maxTimestampToEvict * 1000L
+    store.evictOnWatermark(maxTimestampToEvict, pair => {
+      pair.key.getLong(1) <= maxTimestampToEvictInMicros
+    }).foreach { _ => }
+  }
+
+  private def fullScanAndCompareTimestamp(
+      store: StateStore,
+      maxIdxToEvict: Int): Unit = {
+    var i: Long = 0
+    store.iterator().foreach { r =>
+      if (r.key.getInt(1) < maxIdxToEvict) {
+        // simply to keep the "if statement" from being a no-op
+        i += 1
+      }
+    }
+  }
+
+  private def updateRowsWithSortedMapIndex(
+      store: StateStore,
+      index: jutil.SortedMap[Int, jutil.List[UnsafeRow]],
+      rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = {
+    rows.foreach { case (key, value) =>
+      val idx = key.getInt(1)
+
+      // TODO: rewrite this in atomic way?
+      if (index.containsKey(idx)) {
+        val list = index.get(idx)
+        list.add(key)
+      } else {
+        val list = new jutil.ArrayList[UnsafeRow]()
+        list.add(key)
+        index.put(idx, list)
+      }
+
+      store.put(key, value)
+    }
+  }
+
+  private def evictAsScanSortedMapIndexAndRemove(
+      store: StateStore,
+      index: jutil.SortedMap[Int, jutil.List[UnsafeRow]],
+      maxIdxToEvict: Int): Unit = {
+    val keysToRemove = index.headMap(maxIdxToEvict + 1)
+    val keysToRemoveIter = keysToRemove.entrySet().iterator()
+    while (keysToRemoveIter.hasNext) {
+      val entry = keysToRemoveIter.next()
+      val keys = entry.getValue
+      val keysIter = keys.iterator()
+      while (keysIter.hasNext) {
+        val key = keysIter.next()
+        store.remove(key)
+      }
+      keys.clear()
+      keysToRemoveIter.remove()
+    }
+  }
+
+  // FIXME: should the size of key / value be variables?
+  private def constructTestData(numRows: Int, minIdx: Int = 0): Seq[(UnsafeRow, UnsafeRow)] = {
+    (1 to numRows).map { idx =>
+      val keyRow = new GenericInternalRow(2)
+      keyRow.setInt(0, 1)
+      keyRow.setLong(1, (minIdx + idx) * 1000L) // microseconds
+      val valueRow = new GenericInternalRow(1)
+      valueRow.setInt(0, minIdx + idx)
+
+      val keyUnsafeRow = keyProjection(keyRow).copy()
+      val valueUnsafeRow = valueProjection(valueRow).copy()
+
+      (keyUnsafeRow, valueUnsafeRow)
+    }
+  }
+
+  // This prevents created keys from being in order, which may affect the performance on RocksDB.
+  private def constructRandomizedTestData(
+      numRows: Int,
+      timestamps: List[Long],
+      minIdx: Int = 0): Seq[(UnsafeRow, UnsafeRow)] = {
+    assert(numRows >= timestamps.length)
+    assert(numRows % timestamps.length == 0)
+
+    (1 to numRows).map { idx =>
+      val keyRow = new GenericInternalRow(2)
+      keyRow.setInt(0, Random.nextInt(Int.MaxValue))
+      keyRow.setLong(1, timestamps((minIdx + idx) % timestamps.length)) // microseconds
+      val valueRow = new GenericInternalRow(1)
+      valueRow.setInt(0, minIdx + idx)
+
+      val keyUnsafeRow = keyProjection(keyRow).copy()
+      val valueUnsafeRow = valueProjection(valueRow).copy()
+
+      (keyUnsafeRow, valueUnsafeRow)
+    }
+  }
+
+  private def newHDFSBackedStateStoreProvider(): StateStoreProvider = {
+    val storeId = StateStoreId(newDir(), Random.nextInt(), 0)
+    val provider = new HDFSBackedStateStoreProvider()
+    val sqlConf = new SQLConf()
+    sqlConf.setConfString("spark.sql.streaming.stateStore.compression.codec", "zstd")
+    val storeConf = new StateStoreConf(sqlConf)
+    provider.init(
+      storeId, keySchema, valueSchema, StatefulOperatorContext(),
+      storeConf, new Configuration)
+    provider
+  }
+
+  private def newRocksDBStateProvider(): StateStoreProvider = {
+    val storeId = StateStoreId(newDir(), Random.nextInt(), 0)
+    val provider = new RocksDBStateStoreProvider()
+    val sqlConf = new SQLConf()
+    sqlConf.setConfString("spark.sql.streaming.stateStore.compression.codec", "zstd")
+    val storeConf = new StateStoreConf(sqlConf)
+    provider.init(
+      storeId, keySchema, valueSchema, StatefulOperatorContext(),
+      storeConf, new Configuration)
+    provider
+  }
+
+  private def newRocksDBStateProviderWithEventTimeIdx(): StateStoreProvider = {
+    val storeId = StateStoreId(newDir(), Random.nextInt(), 0)
+    val provider = new RocksDBStateStoreProvider()
+    val sqlConf = new SQLConf()
+    sqlConf.setConfString("spark.sql.streaming.stateStore.compression.codec", "zstd")
+    val storeConf = new StateStoreConf(sqlConf)
+    provider.init(
+      storeId, keySchema, valueSchema, StatefulOperatorContext(eventTimeColIdx = Array(1)),
+      storeConf, new Configuration)
+    provider
+  }
+
+  private def newDir(): String = Utils.createTempDir().toString
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
index 81f1a3f..4f7c040 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
-import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId, StreamingSessionWindowStateManager}
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StatefulOperatorContext, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId, StreamingSessionWindowStateManager}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamTest
 import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
@@ -217,9 +217,12 @@ class MergingSortWithSessionWindowStateIteratorSuite extends StreamTest with Bef
         stateFormatVersion)
 
       val storeProviderId = StateStoreProviderId(stateInfo, 0, StateStoreId.DEFAULT_STORE_NAME)
+
+      // FIXME: event time column?
       val store = StateStore.get(
         storeProviderId, manager.getStateKeySchema, manager.getStateValueSchema,
-        manager.getNumColsForPrefixKey, stateInfo.storeVersion, storeConf, new Configuration)
+        StatefulOperatorContext(numColsPrefixKey = manager.getNumColsForPrefixKey),
+        stateInfo.storeVersion, storeConf, new Configuration)
 
       try {
         f(manager, store)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
index e52ccd0..4b53c0a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
@@ -50,4 +50,18 @@ class MemoryStateStore extends StateStore() {
   override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
     throw new UnsupportedOperationException("Doesn't support prefix scan!")
   }
+
+  /** Lazily removes and returns rows matching `altPred`; `watermarkMs` is not used here. */
+  override def evictOnWatermark(
+      watermarkMs: Long,
+      altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+    iterator().filter { pair =>
+      if (altPred.apply(pair)) {
+        remove(pair.key)
+        true
+      } else {
+        false
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
index 2d741d3..7374303 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters
 import org.scalatest.time.{Minute, Span}
 
 import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper}
-import org.apache.spark.sql.functions.count
+import org.apache.spark.sql.functions.{count, timestamp_seconds, window}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
 
@@ -52,6 +52,64 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest {
     }
   }
 
+  test("append mode") {
+    val inputData = MemoryStream[Int]
+    val conf = Map(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName)
+
+    val windowedAggregation = inputData.toDF()
+      .withColumn("eventTime", timestamp_seconds($"value"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
+
+    testStream(windowedAggregation)(
+      StartStream(additionalConfs = conf),
+      AddData(inputData, 10, 11, 12, 13, 14, 15),
+      CheckNewAnswer(),
+      AddData(inputData, 25),   // Advance watermark to 15 seconds
+      CheckNewAnswer((10, 5)),
+      // assertNumStateRows(2),
+      // assertNumRowsDroppedByWatermark(0),
+      AddData(inputData, 10),   // Should not emit anything as data less than watermark
+      CheckNewAnswer()
+      // assertNumStateRows(2),
+      // assertNumRowsDroppedByWatermark(1)
+    )
+  }
+
+  test("update mode") {
+    val inputData = MemoryStream[Int]
+    val conf = Map(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName)
+
+    val windowedAggregation = inputData.toDF()
+      .withColumn("eventTime", timestamp_seconds($"value"))
+      .withWatermark("eventTime", "10 seconds")
+      .groupBy(window($"eventTime", "5 seconds") as 'window)
+      .agg(count("*") as 'count)
+      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
+
+    testStream(windowedAggregation, OutputMode.Update)(
+      StartStream(additionalConfs = conf),
+      AddData(inputData, 10, 11, 12, 13, 14, 15),
+      CheckNewAnswer((10, 5), (15, 1)),
+      AddData(inputData, 25),     // Advance watermark to 15 seconds
+      CheckNewAnswer((25, 1)),
+      // assertNumStateRows(2),
+      // assertNumRowsDroppedByWatermark(0),
+      AddData(inputData, 10, 25), // Ignore 10 as it's less than watermark
+      CheckNewAnswer((25, 2)),
+      // assertNumStateRows(2),
+      // assertNumRowsDroppedByWatermark(1),
+      AddData(inputData, 10),     // Should not emit anything as data less than watermark
+      CheckNewAnswer()
+      // assertNumStateRows(2),
+      // assertNumRowsDroppedByWatermark(1)
+    )
+  }
+
   test("SPARK-36236: query progress contains only the expected RocksDB store custom metrics") {
     // fails if any new custom metrics are added to remind the author of API changes
     import testImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index c93d0f0..c06a9ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -87,8 +87,9 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
         queryRunId = UUID.randomUUID, operatorId = 0, storeVersion = 0, numPartitions = 5)
 
       // Create state store in a task and get the RocksDBConf from the instantiated RocksDB instance
+      // FIXME: should an event time column be specified here?
       val rocksDBConfInTask: RocksDBConf = testRDD.mapPartitionsWithStateStore[RocksDBConf](
-        spark.sqlContext, testStateInfo, testSchema, testSchema, 0) {
+        spark.sqlContext, testStateInfo, testSchema, testSchema, StatefulOperatorContext()) {
           (store: StateStore, _: Iterator[String]) =>
             // Use reflection to get RocksDB instance
             val dbInstanceMethod =
@@ -144,7 +145,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       numColsPrefixKey: Int): RocksDBStateStoreProvider = {
     val provider = new RocksDBStateStoreProvider()
     provider.init(
-      storeId, keySchema, valueSchema, numColsPrefixKey = numColsPrefixKey,
+      storeId, keySchema, valueSchema,
+      StatefulOperatorContext(numColsPrefixKey = numColsPrefixKey),
       new StateStoreConf, new Configuration)
     provider
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 6bb8ebe..0e9b19d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -60,13 +60,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString
       val rdd1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
         .mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 0),
-          keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+          keySchema, valueSchema, StatefulOperatorContext())(increment)
       assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1))
 
       // Generate next version of stores
       val rdd2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("c", 0)))
         .mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 1),
-          keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+          keySchema, valueSchema, StatefulOperatorContext())(increment)
       assert(rdd2.collect().toSet === Set(("a", 0) -> 3, ("b", 0) -> 1, ("c", 0) -> 1))
 
       // Make sure the previous RDD still has the same data.
@@ -84,7 +84,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       implicit val sqlContext = spark.sqlContext
       makeRDD(spark.sparkContext, Seq(("a", 0))).mapPartitionsWithStateStore(
         sqlContext, operatorStateInfo(path, version = storeVersion),
-        keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+        keySchema, valueSchema, StatefulOperatorContext())(increment)
     }
 
     // Generate RDDs and state store data
@@ -134,19 +134,19 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
 
       val rddOfGets1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0)))
         .mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 0),
-          keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfGets)
+          keySchema, valueSchema, StatefulOperatorContext())(iteratorOfGets)
       assert(rddOfGets1.collect().toSet ===
         Set(("a", 0) -> None, ("b", 0) -> None, ("c", 0) -> None))
 
       val rddOfPuts = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
         .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 0),
-          keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfPuts)
+          keySchema, valueSchema, StatefulOperatorContext())(iteratorOfPuts)
       assert(rddOfPuts.collect().toSet ===
         Set(("a", 0) -> 1, ("a", 0) -> 2, ("b", 0) -> 1))
 
       val rddOfGets2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0)))
         .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 1),
-          keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfGets)
+          keySchema, valueSchema, StatefulOperatorContext())(iteratorOfGets)
       assert(rddOfGets2.collect().toSet ===
         Set(("a", 0) -> Some(2), ("b", 0) -> Some(1), ("c", 0) -> None))
     }
@@ -172,7 +172,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
 
         val rdd = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
           .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, queryRunId = queryRunId),
-          keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+          keySchema, valueSchema, StatefulOperatorContext())(increment)
         require(rdd.partitions.length === 2)
 
         assert(
@@ -200,13 +200,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
         val opId = 0
         val rdd1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
           .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 0),
-            keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+            keySchema, valueSchema, StatefulOperatorContext())(increment)
         assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1))
 
         // Generate next version of stores
         val rdd2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("c", 0)))
           .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 1),
-            keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+            keySchema, valueSchema, StatefulOperatorContext())(increment)
         assert(rdd2.collect().toSet === Set(("a", 0) -> 3, ("b", 0) -> 1, ("c", 0) -> 1))
 
         // Make sure the previous RDD still has the same data.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 601b62b..d89c6fed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -270,8 +270,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
     def generateStoreVersions(): Unit = {
       for (i <- 1 to 20) {
-        val store = StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0,
-          latestStoreVersion, storeConf, hadoopConf)
+        val store = StateStore.get(storeProviderId1, keySchema, valueSchema,
+          StatefulOperatorContext(), latestStoreVersion, storeConf, hadoopConf)
         put(store, "a", 0, i)
         store.commit()
         latestStoreVersion += 1
@@ -324,7 +324,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
           }
 
           // Reload the store and verify
-          StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0,
+          StateStore.get(storeProviderId1, keySchema, valueSchema, StatefulOperatorContext(),
             latestStoreVersion, storeConf, hadoopConf)
           assert(StateStore.isLoaded(storeProviderId1))
 
@@ -336,7 +336,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
           }
 
           // Reload the store and verify
-          StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0,
+          StateStore.get(storeProviderId1, keySchema, valueSchema, StatefulOperatorContext(),
             latestStoreVersion, storeConf, hadoopConf)
           assert(StateStore.isLoaded(storeProviderId1))
 
@@ -344,7 +344,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
           // then this executor should unload inactive instances immediately.
           coordinatorRef
             .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
-          StateStore.get(storeProviderId2, keySchema, valueSchema, numColsPrefixKey = 0,
+          StateStore.get(storeProviderId2, keySchema, valueSchema, StatefulOperatorContext(),
             0, storeConf, hadoopConf)
           assert(!StateStore.isLoaded(storeProviderId1))
           assert(StateStore.isLoaded(storeProviderId2))
@@ -453,7 +453,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     // Getting the store should not create temp file
     val store0 = shouldNotCreateTempFile {
       StateStore.get(
-        storeId, keySchema, valueSchema, numColsPrefixKey = 0,
+        storeId, keySchema, valueSchema, StatefulOperatorContext(),
         version = 0, storeConf, hadoopConf)
     }
 
@@ -470,7 +470,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     // Remove should create a temp file
     val store1 = shouldNotCreateTempFile {
       StateStore.get(
-        storeId, keySchema, valueSchema, numColsPrefixKey = 0,
+        storeId, keySchema, valueSchema, StatefulOperatorContext(),
         version = 1, storeConf, hadoopConf)
     }
     remove(store1, _._1 == "a")
@@ -485,7 +485,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     // Commit without any updates should create a delta file
     val store2 = shouldNotCreateTempFile {
       StateStore.get(
-        storeId, keySchema, valueSchema, numColsPrefixKey = 0,
+        storeId, keySchema, valueSchema, StatefulOperatorContext(),
         version = 2, storeConf, hadoopConf)
     }
     store2.commit()
@@ -720,11 +720,12 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
       hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = {
     val sqlConf = getDefaultSQLConf(minDeltasForSnapshot, numOfVersToRetainInMemory)
     val provider = new HDFSBackedStateStoreProvider()
+    // FIXME: should an event time column be specified here?
     provider.init(
       StateStoreId(dir, opId, partition),
       keySchema,
       valueSchema,
-      numColsPrefixKey = numColsPrefixKey,
+      StatefulOperatorContext(numColsPrefixKey = numColsPrefixKey),
       new StateStoreConf(sqlConf),
       hadoopConf)
     provider
@@ -1027,31 +1028,31 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       // Verify that trying to get incorrect versions throw errors
       intercept[IllegalArgumentException] {
         StateStore.get(
-          storeId, keySchema, valueSchema, 0, -1, storeConf, hadoopConf)
+          storeId, keySchema, valueSchema, StatefulOperatorContext(), -1, storeConf, hadoopConf)
       }
       assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store
 
       intercept[IllegalStateException] {
         StateStore.get(
-          storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf)
+          storeId, keySchema, valueSchema, StatefulOperatorContext(), 1, storeConf, hadoopConf)
       }
 
       // Increase version of the store and try to get again
       val store0 = StateStore.get(
-        storeId, keySchema, valueSchema, 0, 0, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, StatefulOperatorContext(), 0, storeConf, hadoopConf)
       assert(store0.version === 0)
       put(store0, "a", 0, 1)
       store0.commit()
 
       val store1 = StateStore.get(
-        storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, StatefulOperatorContext(), 1, storeConf, hadoopConf)
       assert(StateStore.isLoaded(storeId))
       assert(store1.version === 1)
       assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1))
 
       // Verify that you can also load older version
       val store0reloaded = StateStore.get(
-        storeId, keySchema, valueSchema, 0, 0, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, StatefulOperatorContext(), 0, storeConf, hadoopConf)
       assert(store0reloaded.version === 0)
       assert(rowPairsToDataSet(store0reloaded.iterator()) === Set.empty)
 
@@ -1060,7 +1061,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       assert(!StateStore.isLoaded(storeId))
 
       val store1reloaded = StateStore.get(
-        storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, StatefulOperatorContext(), 1, storeConf, hadoopConf)
       assert(StateStore.isLoaded(storeId))
       assert(store1reloaded.version === 1)
       put(store1reloaded, "a", 0, 2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala
index 096c3bb..dcdbac9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala
@@ -181,9 +181,11 @@ class StreamingSessionWindowStateManagerSuite extends StreamTest with BeforeAndA
         stateFormatVersion)
 
       val storeProviderId = StateStoreProviderId(stateInfo, 0, StateStoreId.DEFAULT_STORE_NAME)
+      // FIXME: should an event time column be specified here?
       val store = StateStore.get(
         storeProviderId, manager.getStateKeySchema, manager.getStateValueSchema,
-        manager.getNumColsForPrefixKey, stateInfo.storeVersion, storeConf, new Configuration)
+        StatefulOperatorContext(numColsPrefixKey = manager.getNumColsForPrefixKey),
+        stateInfo.storeVersion, storeConf, new Configuration)
 
       try {
         f(manager, store)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index e89197b..d376942 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.{ContinuousMemoryStream, MemorySink}
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
+import org.apache.spark.sql.execution.streaming.state.{StatefulOperatorContext, StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.StreamSourceProvider
@@ -1418,7 +1418,7 @@ class TestStateStoreProvider extends StateStoreProvider {
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
+      operatorContext: StatefulOperatorContext,
       storeConfs: StateStoreConf,
       hadoopConf: Configuration): Unit = {
     throw new Exception("Successfully instantiated")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 77334ad..fbf3aae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.MemorySink
-import org.apache.spark.sql.execution.streaming.state.{StateSchemaNotCompatible, StateStore, StreamingAggregationStateManager}
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateSchemaNotCompatible, StateStore, StreamingAggregationStateManager}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode._
@@ -53,29 +53,47 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
   import testImplicits._
 
   def executeFuncWithStateVersionSQLConf(
+      providerCls: String,
       stateVersion: Int,
       confPairs: Seq[(String, String)],
       func: => Any): Unit = {
     withSQLConf(confPairs ++
-      Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) {
+      Seq(
+        SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString,
+        SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerCls.stripSuffix("$")): _*) {
       func
     }
   }
 
   def testWithAllStateVersions(name: String, confPairs: (String, String)*)
                               (func: => Any): Unit = {
-    for (version <- StreamingAggregationStateManager.supportedVersions) {
-      test(s"$name - state format version $version") {
-        executeFuncWithStateVersionSQLConf(version, confPairs, func)
+    val providers = Seq(
+      // FIXME: testing with RocksDB only for now; re-enable the HDFS-backed provider below.
+      // classOf[HDFSBackedStateStoreProvider].getCanonicalName,
+      classOf[RocksDBStateStoreProvider].getCanonicalName)
+
+    for (
+      version <- StreamingAggregationStateManager.supportedVersions;
+      provider <- providers
+    ) yield {
+      test(s"$name - state format version $version / provider: $provider") {
+        executeFuncWithStateVersionSQLConf(provider, version, confPairs, func)
       }
     }
   }
 
   def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*)
                                      (func: => Any): Unit = {
-    for (version <- StreamingAggregationStateManager.supportedVersions) {
-      testQuietly(s"$name - state format version $version") {
-        executeFuncWithStateVersionSQLConf(version, confPairs, func)
+    val providers = Seq(
+      classOf[HDFSBackedStateStoreProvider].getCanonicalName,
+      classOf[RocksDBStateStoreProvider].getCanonicalName)
+
+    for (
+      version <- StreamingAggregationStateManager.supportedVersions;
+      provider <- providers
+    ) yield {
+      testQuietly(s"$name - state format version $version / provider: $provider") {
+        executeFuncWithStateVersionSQLConf(provider, version, confPairs, func)
       }
     }
   }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org