Posted to commits@spark.apache.org by ka...@apache.org on 2021/11/05 07:42:59 UTC

[spark] 02/02: WIP still need to add e2e test and address FIXME/TODOs

This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch WIP-optimize-eviction-in-rocksdb-state-store
in repository https://gitbox.apache.org/repos/asf/spark.git

commit 88187c64d70a5566bb8f07ce4b133a33e63ce5fc
Author: Jungtaek Lim <ka...@gmail.com>
AuthorDate: Fri Nov 5 16:41:24 2021 +0900

    WIP still need to add e2e test and address FIXME/TODOs
---
 .../sql/execution/streaming/state/RocksDB.scala    | 172 ++++-------------
 .../streaming/state/RocksDBFileManager.scala       |  35 +++-
 .../streaming/state/RocksDBStateEncoder.scala      |  96 +---------
 .../state/RocksDBStateStoreProvider.scala          | 213 +++++++++++++--------
 .../execution/benchmark/StateStoreBenchmark.scala  |  25 ++-
 .../execution/streaming/state/RocksDBSuite.scala   |   8 +-
 6 files changed, 227 insertions(+), 322 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index eed7827..105a446 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.execution.streaming.state
 
 import java.io.File
-import java.util
 import java.util.Locale
 import javax.annotation.concurrent.GuardedBy
 
@@ -51,12 +50,9 @@ import org.apache.spark.util.{NextIterator, Utils}
  * @param hadoopConf   Hadoop configuration for talking to the remote file system
  * @param loggingId    Id that will be prepended in logs for isolating concurrent RocksDBs
  */
-// FIXME: optionally receiving column families
 class RocksDB(
     dfsRootDir: String,
     val conf: RocksDBConf,
-    // TODO: change "default" to constant
-    columnFamilies: Seq[String] = Seq("default"),
     localRootDir: File = Utils.createTempDir(),
     hadoopConf: Configuration = new Configuration,
     loggingId: String = "") extends Logging {
@@ -69,10 +65,16 @@ class RocksDB(
   private val flushOptions = new FlushOptions().setWaitForFlush(true)  // wait for flush to complete
   private val writeBatch = new WriteBatchWithIndex(true)  // overwrite multiple updates to a key
 
-  private val dbOptions: DBOptions = new DBOptions() // options to open the RocksDB
-  dbOptions.setCreateIfMissing(true)
-  dbOptions.setCreateMissingColumnFamilies(true)
+  private val bloomFilter = new BloomFilter()
+  private val tableFormatConfig = new BlockBasedTableConfig()
+  tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024)
+  tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024))
+  tableFormatConfig.setFilterPolicy(bloomFilter)
+  tableFormatConfig.setFormatVersion(conf.formatVersion)
 
+  private val dbOptions = new Options() // options to open the RocksDB
+  dbOptions.setCreateIfMissing(true)
+  dbOptions.setTableFormatConfig(tableFormatConfig)
   private val dbLogger = createLogger() // for forwarding RocksDB native logs to log4j
   dbOptions.setStatistics(new Statistics())
   private val nativeStats = dbOptions.statistics()
@@ -85,18 +87,18 @@ class RocksDB(
   private val acquireLock = new Object
 
   @volatile private var db: NativeRocksDB = _
-  @volatile private var columnFamilyHandles: util.Map[String, ColumnFamilyHandle] = _
-  @volatile private var defaultColumnFamilyHandle: ColumnFamilyHandle = _
   @volatile private var loadedVersion = -1L   // -1 = nothing valid is loaded
   @volatile private var numKeysOnLoadedVersion = 0L
   @volatile private var numKeysOnWritingVersion = 0L
   @volatile private var fileManagerMetrics = RocksDBFileManagerMetrics.EMPTY_METRICS
+  @volatile private var customMetadataOnLoadedVersion: Map[String, String] = Map.empty
+  @volatile private var customMetadataOnWritingVersion: Map[String, String] = Map.empty
 
   @GuardedBy("acquireLock")
   @volatile private var acquiredThreadInfo: AcquiredThreadInfo = _
 
   private val prefixScanReuseIter =
-    new java.util.concurrent.ConcurrentHashMap[(Long, Int), RocksIterator]()
+    new java.util.concurrent.ConcurrentHashMap[Long, RocksIterator]()
 
   /**
    * Load the given version of data in a native RocksDB instance.
@@ -114,6 +116,8 @@ class RocksDB(
         openDB()
         numKeysOnWritingVersion = metadata.numKeys
         numKeysOnLoadedVersion = metadata.numKeys
+        customMetadataOnLoadedVersion = metadata.customMetadata
+        customMetadataOnWritingVersion = metadata.customMetadata
         loadedVersion = version
         fileManagerMetrics = fileManager.latestLoadCheckpointMetrics
       }
@@ -137,28 +141,7 @@ class RocksDB(
    * @note This will return the last written value even if it was uncommitted.
    */
   def get(key: Array[Byte]): Array[Byte] = {
-    get(defaultColumnFamilyHandle, key)
-  }
-
-  // FIXME: method doc
-  def get(cf: String, key: Array[Byte]): Array[Byte] = {
-    get(findColumnFamilyHandle(cf), key)
-  }
-
-  private def get(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = {
-    writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
-  }
-
-  def merge(key: Array[Byte], value: Array[Byte]): Unit = {
-    merge(defaultColumnFamilyHandle, key, value)
-  }
-
-  def merge(cf: String, key: Array[Byte], value: Array[Byte]): Unit = {
-    merge(findColumnFamilyHandle(cf), key, value)
-  }
-
-  private def merge(cfHandle: ColumnFamilyHandle, key: Array[Byte], value: Array[Byte]): Unit = {
-    writeBatch.merge(cfHandle, key, value)
+    writeBatch.getFromBatchAndDB(db, readOptions, key)
   }
 
   /**
@@ -166,20 +149,8 @@ class RocksDB(
    * @note This update is not committed to disk until commit() is called.
    */
   def put(key: Array[Byte], value: Array[Byte]): Array[Byte] = {
-    put(defaultColumnFamilyHandle, key, value)
-  }
-
-  // FIXME: method doc
-  def put(cf: String, key: Array[Byte], value: Array[Byte]): Array[Byte] = {
-    put(findColumnFamilyHandle(cf), key, value)
-  }
-
-  private def put(
-      cfHandle: ColumnFamilyHandle,
-      key: Array[Byte],
-      value: Array[Byte]): Array[Byte] = {
-    val oldValue = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
-    writeBatch.put(cfHandle, key, value)
+    val oldValue = writeBatch.getFromBatchAndDB(db, readOptions, key)
+    writeBatch.put(key, value)
     if (oldValue == null) {
       numKeysOnWritingVersion += 1
     }
@@ -191,18 +162,9 @@ class RocksDB(
    * @note This update is not committed to disk until commit() is called.
    */
   def remove(key: Array[Byte]): Array[Byte] = {
-    remove(defaultColumnFamilyHandle, key)
-  }
-
-  // FIXME: method doc
-  def remove(cf: String, key: Array[Byte]): Array[Byte] = {
-    remove(findColumnFamilyHandle(cf), key)
-  }
-
-  private def remove(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = {
-    val value = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
+    val value = writeBatch.getFromBatchAndDB(db, readOptions, key)
     if (value != null) {
-      writeBatch.delete(cfHandle, key)
+      writeBatch.remove(key)
       numKeysOnWritingVersion -= 1
     }
     value
@@ -211,17 +173,8 @@ class RocksDB(
   /**
    * Get an iterator of all committed and uncommitted key-value pairs.
    */
-  def iterator(): NextIterator[ByteArrayPair] = {
-    iterator(defaultColumnFamilyHandle)
-  }
-
-  // FIXME: doc
-  def iterator(cf: String): NextIterator[ByteArrayPair] = {
-    iterator(findColumnFamilyHandle(cf))
-  }
-
-  private def iterator(cfHandle: ColumnFamilyHandle): NextIterator[ByteArrayPair] = {
-    val iter = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle))
+  def iterator(): Iterator[ByteArrayPair] = {
+    val iter = writeBatch.newIteratorWithBase(db.newIterator())
     logInfo(s"Getting iterator from version $loadedVersion")
     iter.seekToFirst()
 
@@ -248,20 +201,11 @@ class RocksDB(
   }
 
   def prefixScan(prefix: Array[Byte]): Iterator[ByteArrayPair] = {
-    prefixScan(defaultColumnFamilyHandle, prefix)
-  }
-
-  def prefixScan(cf: String, prefix: Array[Byte]): Iterator[ByteArrayPair] = {
-    prefixScan(findColumnFamilyHandle(cf), prefix)
-  }
-
-  private def prefixScan(
-      cfHandle: ColumnFamilyHandle, prefix: Array[Byte]): Iterator[ByteArrayPair] = {
     val threadId = Thread.currentThread().getId
-    val iter = prefixScanReuseIter.computeIfAbsent((threadId, cfHandle.getID), key => {
-      val it = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle))
+    val iter = prefixScanReuseIter.computeIfAbsent(threadId, tid => {
+      val it = writeBatch.newIteratorWithBase(db.newIterator())
       logInfo(s"Getting iterator from version $loadedVersion for prefix scan on " +
-        s"thread ID ${key._1} and column family ID ${key._2}")
+        s"thread ID $tid")
       it
     })
 
@@ -283,12 +227,8 @@ class RocksDB(
     }
   }
 
-  private def findColumnFamilyHandle(cf: String): ColumnFamilyHandle = {
-    val cfHandle = columnFamilyHandles.get(cf)
-    if (cfHandle == null) {
-      throw new IllegalArgumentException(s"Handle for column family $cf is not found")
-    }
-    cfHandle
+  def setCustomMetadata(metadata: Map[String, String]): Unit = {
+    customMetadataOnWritingVersion = metadata
   }
 
   /**
@@ -310,16 +250,11 @@ class RocksDB(
       val writeTimeMs = timeTakenMs { db.write(writeOptions, writeBatch) }
 
       logInfo(s"Flushing updates for $newVersion")
-      val flushTimeMs = timeTakenMs {
-        db.flush(flushOptions,
-          new util.ArrayList[ColumnFamilyHandle](columnFamilyHandles.values()))
-      }
+      val flushTimeMs = timeTakenMs { db.flush(flushOptions) }
 
       val compactTimeMs = if (conf.compactOnCommit) {
         logInfo("Compacting")
-        timeTakenMs {
-          columnFamilyHandles.values().forEach(cfHandle => db.compactRange(cfHandle))
-        }
+        timeTakenMs { db.compactRange() }
       } else 0
 
       logInfo("Pausing background work")
@@ -335,9 +270,11 @@ class RocksDB(
 
       logInfo(s"Syncing checkpoint for $newVersion to DFS")
       val fileSyncTimeMs = timeTakenMs {
-        fileManager.saveCheckpointToDfs(checkpointDir, newVersion, numKeysOnWritingVersion)
+        fileManager.saveCheckpointToDfs(checkpointDir, newVersion, numKeysOnWritingVersion,
+          customMetadataOnWritingVersion.toMap)
       }
       numKeysOnLoadedVersion = numKeysOnWritingVersion
+      customMetadataOnLoadedVersion = customMetadataOnWritingVersion
       loadedVersion = newVersion
       fileManagerMetrics = fileManager.latestSaveCheckpointMetrics
       commitLatencyMs ++= Map(
@@ -352,7 +289,6 @@ class RocksDB(
       loadedVersion
     } catch {
       case t: Throwable =>
-        logWarning(s"ERROR! exc: $t", t)
         loadedVersion = -1  // invalidate loaded version
         throw t
     } finally {
@@ -369,6 +305,7 @@ class RocksDB(
     closePrefixScanIterators()
     writeBatch.clear()
     numKeysOnWritingVersion = numKeysOnLoadedVersion
+    customMetadataOnWritingVersion = customMetadataOnLoadedVersion
     release()
     logInfo(s"Rolled back to $loadedVersion")
   }
@@ -404,6 +341,8 @@ class RocksDB(
   /** Get the latest version available in the DFS */
   def getLatestVersion(): Long = fileManager.getLatestVersion()
 
+  def getCustomMetadata(): Map[String, String] = customMetadataOnWritingVersion
+
   /** Get current instantaneous statistics */
   def metrics: RocksDBMetrics = {
     import HistogramType._
@@ -496,43 +435,12 @@ class RocksDB(
 
   private def openDB(): Unit = {
     assert(db == null)
-
-    val columnFamilyDescriptors = new util.ArrayList[ColumnFamilyDescriptor]()
-    columnFamilies.foreach { cf =>
-      val bloomFilter = new BloomFilter()
-      val tableFormatConfig = new BlockBasedTableConfig()
-      tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024)
-      tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024))
-      tableFormatConfig.setFilterPolicy(bloomFilter)
-      tableFormatConfig.setFormatVersion(conf.formatVersion)
-
-      val columnFamilyOptions = new ColumnFamilyOptions()
-      columnFamilyOptions.setTableFormatConfig(tableFormatConfig)
-      columnFamilyDescriptors.add(new ColumnFamilyDescriptor(cf.getBytes(), columnFamilyOptions))
-    }
-
-    val cfHandles = new util.ArrayList[ColumnFamilyHandle](columnFamilyDescriptors.size())
-    db = NativeRocksDB.open(dbOptions, workingDir.toString, columnFamilyDescriptors,
-      cfHandles)
-
-    columnFamilyHandles = new util.HashMap[String, ColumnFamilyHandle]()
-    columnFamilies.indices.foreach { idx =>
-      columnFamilyHandles.put(columnFamilies(idx), cfHandles.get(idx))
-    }
-
-    // FIXME: constant
-    defaultColumnFamilyHandle = columnFamilyHandles.get("default")
-
+    db = NativeRocksDB.open(dbOptions, workingDir.toString)
     logInfo(s"Opened DB with conf ${conf}")
   }
 
   private def closeDB(): Unit = {
     if (db != null) {
-      columnFamilyHandles.entrySet().forEach(pair => db.destroyColumnFamilyHandle(pair.getValue))
-      columnFamilyHandles.clear()
-      columnFamilyHandles = null
-      defaultColumnFamilyHandle = null
-
       db.close()
       db = null
     }
@@ -546,17 +454,10 @@ class RocksDB(
         // Warn is mapped to info because RocksDB warn is too verbose
         // (e.g. dumps non-warning stuff like stats)
         val loggingFunc: ( => String) => Unit = infoLogLevel match {
-          /*
           case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_)
           case InfoLogLevel.WARN_LEVEL | InfoLogLevel.INFO_LEVEL => logInfo(_)
           case InfoLogLevel.DEBUG_LEVEL => logDebug(_)
           case _ => logTrace(_)
-           */
-          case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_)
-          case InfoLogLevel.WARN_LEVEL => logWarning(_)
-          case InfoLogLevel.INFO_LEVEL => logInfo(_)
-          case InfoLogLevel.DEBUG_LEVEL => logDebug(_)
-          case _ => logTrace(_)
         }
         loggingFunc(s"[NativeRocksDB-${infoLogLevel.getValue}] $logMsg")
       }
@@ -702,7 +603,7 @@ object RocksDBMetrics {
 
 /** Class to wrap RocksDB's native histogram */
 case class RocksDBNativeHistogram(
-    sum: Long, avg: Double, stddev: Double, median: Double, p95: Double, p99: Double, count: Long) {
+  sum: Long, avg: Double, stddev: Double, median: Double, p95: Double, p99: Double, count: Long) {
   def json: String = Serialization.write(this)(RocksDBMetrics.format)
 }
 
@@ -733,4 +634,3 @@ case class AcquiredThreadInfo() {
     s"[ThreadId: ${threadRef.get.map(_.getId)}$taskStr]"
   }
 }
-
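(Not part of the patch.) With the column-family plumbing removed, the wrapper is back to a single
default key space; what is new is the per-version custom metadata that is saved with each checkpoint
and restored on load. A minimal usage sketch, assuming a reachable checkpoint directory; the path and
metadata value are illustrative, and the key matches METADATA_KEY_LOWEST_EVENT_TIME introduced in
RocksDBStateStoreProvider below:

    // Sketch: custom metadata round-trips through commit() and load().
    val db = new RocksDB(dfsRootDir = "/tmp/state-checkpoint", conf = RocksDBConf())
    try {
      db.load(0)                                  // start from an empty version
      db.put("key".getBytes, "value".getBytes)
      db.setCustomMetadata(Map("lowestEventTimeInState" -> "1636099200000000"))
      val v1 = db.commit()                        // metadata is persisted with the checkpoint

      db.load(v1)                                 // reloading the version restores the metadata
      assert(db.getCustomMetadata().contains("lowestEventTimeInState"))
    } finally {
      db.close()
    }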
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
index 23cdbd0..567f916 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
@@ -152,11 +152,15 @@ class RocksDBFileManager(
   def latestSaveCheckpointMetrics: RocksDBFileManagerMetrics = saveCheckpointMetrics
 
   /** Save all the files in given local checkpoint directory as a committed version in DFS */
-  def saveCheckpointToDfs(checkpointDir: File, version: Long, numKeys: Long): Unit = {
+  def saveCheckpointToDfs(
+      checkpointDir: File,
+      version: Long,
+      numKeys: Long,
+      customMetadata: Map[String, String] = Map.empty): Unit = {
     logFilesInDir(checkpointDir, s"Saving checkpoint files for version $version")
     val (localImmutableFiles, localOtherFiles) = listRocksDBFiles(checkpointDir)
     val rocksDBFiles = saveImmutableFilesToDfs(version, localImmutableFiles)
-    val metadata = RocksDBCheckpointMetadata(rocksDBFiles, numKeys)
+    val metadata = RocksDBCheckpointMetadata(rocksDBFiles, numKeys, customMetadata)
     val metadataFile = localMetadataFile(checkpointDir)
     metadata.writeToFile(metadataFile)
     logInfo(s"Written metadata for version $version:\n${metadata.prettyJson}")
@@ -184,7 +188,7 @@ class RocksDBFileManager(
     val metadata = if (version == 0) {
       if (localDir.exists) Utils.deleteRecursively(localDir)
       localDir.mkdirs()
-      RocksDBCheckpointMetadata(Seq.empty, 0)
+      RocksDBCheckpointMetadata(Seq.empty, 0, Map.empty)
     } else {
       // Delete all non-immutable files in local dir, and unzip new ones from DFS commit file
       listRocksDBFiles(localDir)._2.foreach(_.delete())
@@ -540,12 +544,20 @@ object RocksDBFileManagerMetrics {
 case class RocksDBCheckpointMetadata(
     sstFiles: Seq[RocksDBSstFile],
     logFiles: Seq[RocksDBLogFile],
-    numKeys: Long) {
+    numKeys: Long,
+    customMetadata: Map[String, String]) {
   import RocksDBCheckpointMetadata._
 
   def json: String = {
-    // We turn this field into a null to avoid write a empty logFiles field in the json.
-    val nullified = if (logFiles.isEmpty) this.copy(logFiles = null) else this
+    // We turn the following fields into null to avoid writing them to the json when they are empty:
+    // - logFiles
+    // - customMetadata
+    val nullified = {
+      var cur = this
+      cur = if (logFiles.isEmpty) cur.copy(logFiles = null) else cur
+      cur = if (customMetadata.isEmpty) cur.copy(customMetadata = null) else cur
+      cur
+    }
     mapper.writeValueAsString(nullified)
   }
 
@@ -593,11 +605,18 @@ object RocksDBCheckpointMetadata {
     }
   }
 
-  def apply(rocksDBFiles: Seq[RocksDBImmutableFile], numKeys: Long): RocksDBCheckpointMetadata = {
+  def apply(
+      rocksDBFiles: Seq[RocksDBImmutableFile],
+      numKeys: Long): RocksDBCheckpointMetadata = apply(rocksDBFiles, numKeys, Map.empty)
+
+  def apply(
+      rocksDBFiles: Seq[RocksDBImmutableFile],
+      numKeys: Long,
+      customMetadata: Map[String, String]): RocksDBCheckpointMetadata = {
     val sstFiles = rocksDBFiles.collect { case file: RocksDBSstFile => file }
     val logFiles = rocksDBFiles.collect { case file: RocksDBLogFile => file }
 
-    RocksDBCheckpointMetadata(sstFiles, logFiles, numKeys)
+    RocksDBCheckpointMetadata(sstFiles, logFiles, numKeys, customMetadata)
   }
 }
 
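(Not part of the patch.) The checkpoint metadata gains an optional customMetadata map and reuses the
existing trick of nulling out empty fields so they are omitted from the JSON. A hedged sketch of the
two companion overloads; the file names, sizes, and the JSON in the comments are only the expected
shape, not verified output:

    val sstFiles = Seq(RocksDBSstFile("00001.sst", "00001-uuid.sst", 1024L))

    // Two-argument overload: customMetadata defaults to empty and is dropped from the JSON.
    val withoutMeta = RocksDBCheckpointMetadata(sstFiles, numKeys = 10L)
    // withoutMeta.json ~ {"sstFiles":[...],"numKeys":10}

    // Three-argument overload: a non-empty map is kept and serialized.
    val withMeta = RocksDBCheckpointMetadata(sstFiles, 10L, Map("lowestEventTimeInState" -> "0"))
    // withMeta.json ~ {"sstFiles":[...],"numKeys":10,"customMetadata":{"lowestEventTimeInState":"0"}}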
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index 323826d..84e9a8d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -22,7 +22,7 @@ import java.nio.ByteOrder
 
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION}
-import org.apache.spark.sql.types.{StructField, StructType, TimestampType}
+import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.unsafe.Platform
 
 sealed trait RocksDBStateEncoder {
@@ -30,11 +30,6 @@ sealed trait RocksDBStateEncoder {
   def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
   def extractPrefixKey(key: UnsafeRow): UnsafeRow
 
-  def supportEventTimeIndex: Boolean
-  def extractEventTime(key: UnsafeRow): Long
-  def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte]
-  def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte])
-
   def encodeKey(row: UnsafeRow): Array[Byte]
   def encodeValue(row: UnsafeRow): Array[Byte]
 
@@ -47,13 +42,11 @@ object RocksDBStateEncoder {
   def getEncoder(
       keySchema: StructType,
       valueSchema: StructType,
-      numColsPrefixKey: Int,
-      eventTimeColIdx: Array[Int]): RocksDBStateEncoder = {
+      numColsPrefixKey: Int): RocksDBStateEncoder = {
     if (numColsPrefixKey > 0) {
-      // FIXME: need to deal with prefix case as well
       new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
     } else {
-      new NoPrefixKeyStateEncoder(keySchema, valueSchema, eventTimeColIdx)
+      new NoPrefixKeyStateEncoder(keySchema, valueSchema)
     }
   }
 
@@ -228,23 +221,6 @@ class PrefixKeyScanStateEncoder(
   }
 
   override def supportPrefixKeyScan: Boolean = true
-
-  override def supportEventTimeIndex: Boolean = false
-
-  // FIXME: fix me!
-  def extractEventTime(key: UnsafeRow): Long = {
-    throw new IllegalStateException("This encoder doesn't support event time index!")
-  }
-
-  // FIXME: fix me!
-  def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = {
-    throw new IllegalStateException("This encoder doesn't support event time index!")
-  }
-
-  // FIXME: fix me!
-  def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = {
-    throw new IllegalStateException("This encoder doesn't support event time index!")
-  }
 }
 
 /**
@@ -259,8 +235,7 @@ class PrefixKeyScanStateEncoder(
  */
 class NoPrefixKeyStateEncoder(
     keySchema: StructType,
-    valueSchema: StructType,
-    eventTimeColIdx: Array[Int]) extends RocksDBStateEncoder {
+    valueSchema: StructType) extends RocksDBStateEncoder {
 
   import RocksDBStateEncoder._
 
@@ -269,32 +244,6 @@ class NoPrefixKeyStateEncoder(
   private val valueRow = new UnsafeRow(valueSchema.size)
   private val rowTuple = new UnsafeRowPair()
 
-  validateColumnTypeOnEventTimeColumn()
-
-  private def validateColumnTypeOnEventTimeColumn(): Unit = {
-    if (eventTimeColIdx.nonEmpty) {
-      var curSchema: StructType = keySchema
-      eventTimeColIdx.dropRight(1).foreach { idx =>
-        curSchema(idx).dataType match {
-          case stType: StructType =>
-            curSchema = stType
-          case _ =>
-            // FIXME: better error message
-            throw new IllegalStateException("event time column is not properly specified! " +
-              s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
-        }
-      }
-
-      curSchema(eventTimeColIdx.last).dataType match {
-        case _: TimestampType =>
-        case _ =>
-          // FIXME: better error message
-          throw new IllegalStateException("event time column is not properly specified! " +
-            s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
-      }
-    }
-  }
-
   override def encodeKey(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
 
   override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
@@ -337,41 +286,4 @@ class NoPrefixKeyStateEncoder(
   override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
     throw new IllegalStateException("This encoder doesn't support prefix key!")
   }
-
-  override def supportEventTimeIndex: Boolean = eventTimeColIdx.nonEmpty
-
-  override def extractEventTime(key: UnsafeRow): Long = {
-    var curRow: UnsafeRow = key
-    var curSchema: StructType = keySchema
-
-    eventTimeColIdx.dropRight(1).foreach { idx =>
-      // validation is done in initialization phase
-      curSchema = curSchema(idx).dataType.asInstanceOf[StructType]
-      curRow = curRow.getStruct(idx, curSchema.length)
-    }
-
-    curRow.getLong(eventTimeColIdx.last) / 1000
-  }
-
-  override def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = {
-    val newKey = new Array[Byte](8 + encodedKey.length)
-
-    Platform.putLong(newKey, Platform.BYTE_ARRAY_OFFSET,
-      BinarySortable.encodeToBinarySortableLong(timestamp))
-    Platform.copyMemory(encodedKey, Platform.BYTE_ARRAY_OFFSET, newKey,
-      Platform.BYTE_ARRAY_OFFSET + 8, encodedKey.length)
-
-    newKey
-  }
-
-  override def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = {
-    val encoded = Platform.getLong(eventTimeBytes, Platform.BYTE_ARRAY_OFFSET)
-    val timestamp = BinarySortable.decodeBinarySortableLong(encoded)
-
-    val key = new Array[Byte](eventTimeBytes.length - 8)
-    Platform.copyMemory(eventTimeBytes, Platform.BYTE_ARRAY_OFFSET + 8,
-      key, Platform.BYTE_ARRAY_OFFSET, eventTimeBytes.length - 8)
-
-    (timestamp, key)
-  }
 }
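(Not part of the patch.) With the event-time concerns moved out of the encoders, getEncoder now
selects purely on the number of prefix-key columns. A small sketch with illustrative schemas:

    import org.apache.spark.sql.types._

    val keySchema = new StructType()
      .add("group", StringType)
      .add("eventTime", TimestampType)
    val valueSchema = new StructType().add("count", LongType)

    // numColsPrefixKey > 0  -> PrefixKeyScanStateEncoder (prefix scan supported)
    // numColsPrefixKey == 0 -> NoPrefixKeyStateEncoder (event-time handling now lives in the provider)
    val encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema, numColsPrefixKey = 0)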
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 1d66220..d7c2e0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -24,15 +24,14 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.unsafe.Platform
+import org.apache.spark.sql.types.{StructType, TimestampType}
 import org.apache.spark.util.{NextIterator, Utils}
 
 private[sql] class RocksDBStateStoreProvider
   extends StateStoreProvider with Logging with Closeable {
   import RocksDBStateStoreProvider._
 
-  class RocksDBStateStore(lastVersion: Long) extends StateStore {
+  class RocksDBStateStore(lastVersion: Long, eventTimeColIdx: Array[Int]) extends StateStore {
     /** Trait and classes representing the internal state of the store */
     trait STATE
     case object UPDATING extends STATE
@@ -42,6 +41,40 @@ private[sql] class RocksDBStateStoreProvider
     @volatile private var state: STATE = UPDATING
     @volatile private var isValidated = false
 
+    private val supportEventTimeIndex: Boolean = eventTimeColIdx.nonEmpty
+
+    if (supportEventTimeIndex) {
+      validateColumnTypeOnEventTimeColumn()
+    }
+
+    private def validateColumnTypeOnEventTimeColumn(): Unit = {
+      require(eventTimeColIdx.nonEmpty)
+
+      var curSchema: StructType = keySchema
+      eventTimeColIdx.dropRight(1).foreach { idx =>
+        curSchema(idx).dataType match {
+          case stType: StructType =>
+            curSchema = stType
+          case _ =>
+            // FIXME: better error message
+            throw new IllegalStateException("event time column is not properly specified! " +
+              s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
+        }
+      }
+
+      curSchema(eventTimeColIdx.last).dataType match {
+        case _: TimestampType =>
+        case _ =>
+          // FIXME: better error message
+          throw new IllegalStateException("event time column is not properly specified! " +
+            s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
+      }
+    }
+
+    private var lowestEventTime: Long = Option(rocksDB.getCustomMetadata())
+      .flatMap(_.get(METADATA_KEY_LOWEST_EVENT_TIME).map(_.toLong))
+      .getOrElse(INVALID_LOWEST_EVENT_TIME_VALUE)
+
     override def id: StateStoreId = RocksDBStateStoreProvider.this.stateStoreId
 
     override def version: Long = lastVersion
@@ -66,19 +99,29 @@ private[sql] class RocksDBStateStoreProvider
       val encodedValue = encoder.encodeValue(value)
       rocksDB.put(encodedKey, encodedValue)
 
-      if (encoder.supportEventTimeIndex) {
-        val timestamp = encoder.extractEventTime(key)
-        val tsKey = encoder.encodeEventTimeIndexKey(timestamp, encodedKey)
-
-        rocksDB.put(CF_EVENT_TIME_INDEX, tsKey, Array.empty)
+      if (supportEventTimeIndex) {
+        val eventTimeValue = extractEventTime(key)
+        if (lowestEventTime != INVALID_LOWEST_EVENT_TIME_VALUE
+          && lowestEventTime > eventTimeValue) {
+          lowestEventTime = eventTimeValue
+        }
       }
     }
 
     override def remove(key: UnsafeRow): Unit = {
       verify(state == UPDATING, "Cannot remove after already committed or aborted")
       verify(key != null, "Key cannot be null")
-      // FIXME: this should reflect the index
+
       rocksDB.remove(encoder.encodeKey(key))
+      if (supportEventTimeIndex) {
+        val eventTimeValue = extractEventTime(key)
+        if (lowestEventTime == eventTimeValue) {
+          // We can't track the next lowest event time without scanning all keys.
+          // Mark the lowest event time as invalid, so that a scan happens in the evict phase
+          // and the value is correctly updated there.
+          lowestEventTime = INVALID_LOWEST_EVENT_TIME_VALUE
+        }
+      }
     }
 
     override def iterator(): Iterator[UnsafeRowPair] = {
@@ -100,8 +143,89 @@ private[sql] class RocksDBStateStoreProvider
       rocksDB.prefixScan(prefix).map(kv => encoder.decode(kv))
     }
 
+    /** FIXME: method doc */
+    override def evictOnWatermark(
+        watermarkMs: Long,
+        altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+      if (supportEventTimeIndex) {
+        // Convert lowestEventTime to milliseconds and compare it to watermarkMs.
+        // Subtract 1 ms to avoid an edge case in the conversion from microseconds to milliseconds.
+        if (lowestEventTime != INVALID_LOWEST_EVENT_TIME_VALUE
+          && ((lowestEventTime / 1000) - 1 > watermarkMs)) {
+          Iterator.empty
+        } else {
+          // start with invalidating the lowest event time
+          lowestEventTime = INVALID_LOWEST_EVENT_TIME_VALUE
+
+          new NextIterator[UnsafeRowPair] {
+            private val iter = rocksDB.iterator()
+
+            // here we use Long.MaxValue as invalid value
+            private var lowestEventTimeInIter = Long.MaxValue
+
+            override protected def getNext(): UnsafeRowPair = {
+              var result: UnsafeRowPair = null
+              while (result == null && iter.hasNext) {
+                val kv = iter.next()
+                val rowPair = encoder.decode(kv)
+                if (altPred(rowPair)) {
+                  rocksDB.remove(kv.key)
+                  result = rowPair
+                } else {
+                  val eventTime = extractEventTime(rowPair.key)
+                  if (lowestEventTimeInIter > eventTime) {
+                    lowestEventTimeInIter = eventTime
+                  }
+                }
+              }
+
+              if (result == null) {
+                finished = true
+                null
+              } else {
+                result
+              }
+            }
+
+            override protected def close(): Unit = {
+              if (lowestEventTimeInIter != Long.MaxValue) {
+                lowestEventTime = lowestEventTimeInIter
+              }
+            }
+          }
+        }
+      } else {
+        rocksDB.iterator().flatMap { kv =>
+          val rowPair = encoder.decode(kv)
+          if (altPred(rowPair)) {
+            rocksDB.remove(kv.key)
+            Some(rowPair)
+          } else {
+            None
+          }
+        }
+      }
+    }
+
+    private def extractEventTime(key: UnsafeRow): Long = {
+      var curRow: UnsafeRow = key
+      var curSchema: StructType = keySchema
+
+      eventTimeColIdx.dropRight(1).foreach { idx =>
+        // validation is done in initialization phase
+        curSchema = curSchema(idx).dataType.asInstanceOf[StructType]
+        curRow = curRow.getStruct(idx, curSchema.length)
+      }
+
+      curRow.getLong(eventTimeColIdx.last)
+    }
+
     override def commit(): Long = synchronized {
       verify(state == UPDATING, "Cannot commit after already committed or aborted")
+
+      // set the metadata to RocksDB instance so that it can be committed as well
+      rocksDB.setCustomMetadata(Map(METADATA_KEY_LOWEST_EVENT_TIME -> lowestEventTime.toString))
+
       val newVersion = rocksDB.commit()
       state = COMMITTED
       logInfo(s"Committed $newVersion for $id")
@@ -173,68 +297,6 @@ private[sql] class RocksDBStateStoreProvider
 
     /** Return the [[RocksDB]] instance in this store. This is exposed mainly for testing. */
     def dbInstance(): RocksDB = rocksDB
-
-    /** FIXME: method doc */
-    override def evictOnWatermark(
-        watermarkMs: Long,
-        altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
-
-      if (encoder.supportEventTimeIndex) {
-        val kv = new ByteArrayPair()
-
-        // FIXME: DEBUG
-        // logWarning(s"DEBUG: start iterating event time index, watermark $watermarkMs")
-
-        new NextIterator[UnsafeRowPair] {
-          private val iter = rocksDB.iterator(CF_EVENT_TIME_INDEX)
-          override protected def getNext(): UnsafeRowPair = {
-            if (iter.hasNext) {
-              val pair = iter.next()
-
-              val encodedTs = Platform.getLong(pair.key, Platform.BYTE_ARRAY_OFFSET)
-              val decodedTs = BinarySortable.decodeBinarySortableLong(encodedTs)
-
-              // FIXME: DEBUG
-              // logWarning(s"DEBUG: decoded TS: $decodedTs")
-
-              if (decodedTs > watermarkMs) {
-                finished = true
-                null
-              } else {
-                // FIXME: can we leverage deleteRange to bulk delete on index?
-                rocksDB.remove(CF_EVENT_TIME_INDEX, pair.key)
-                val (_, encodedKey) = encoder.decodeEventTimeIndexKey(pair.key)
-                val value = rocksDB.get(encodedKey)
-                if (value == null) {
-                  throw new IllegalStateException("Event time index has been broken!")
-                }
-                kv.set(encodedKey, value)
-                val rowPair = encoder.decode(kv)
-                rocksDB.remove(encodedKey)
-                rowPair
-              }
-            } else {
-              finished = true
-              null
-            }
-          }
-
-          override protected def close(): Unit = {
-            iter.closeIfNeeded()
-          }
-        }
-      } else {
-        rocksDB.iterator().flatMap { kv =>
-          val rowPair = encoder.decode(kv)
-          if (altPred(rowPair)) {
-            rocksDB.remove(kv.key)
-            Some(rowPair)
-          } else {
-            None
-          }
-        }
-      }
-    }
   }
 
   override def init(
@@ -257,7 +319,7 @@ private[sql] class RocksDBStateStoreProvider
     this.operatorContext = operatorContext
 
     this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema,
-      operatorContext.numColsPrefixKey, operatorContext.eventTimeColIdx)
+      operatorContext.numColsPrefixKey)
 
     rocksDB // lazy initialization
   }
@@ -267,7 +329,7 @@ private[sql] class RocksDBStateStoreProvider
   override def getStore(version: Long): StateStore = {
     require(version >= 0, "Version cannot be less than 0")
     rocksDB.load(version)
-    new RocksDBStateStore(version)
+    new RocksDBStateStore(version, operatorContext.eventTimeColIdx)
   }
 
   override def doMaintenance(): Unit = {
@@ -298,7 +360,6 @@ private[sql] class RocksDBStateStoreProvider
     val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
     val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr)
     new RocksDB(dfsRootDir, RocksDBConf(storeConf),
-      columnFamilies = Seq("default", RocksDBStateStoreProvider.CF_EVENT_TIME_INDEX),
       localRootDir = localRootDir,
       hadoopConf = hadoopConf, loggingId = storeIdStr)
   }
@@ -315,8 +376,8 @@ object RocksDBStateStoreProvider {
   val STATE_ENCODING_NUM_VERSION_BYTES = 1
   val STATE_ENCODING_VERSION: Byte = 0
 
-  // reserved column families
-  val CF_EVENT_TIME_INDEX: String = "__event_time_idx"
+  val INVALID_LOWEST_EVENT_TIME_VALUE: Long = Long.MinValue
+  val METADATA_KEY_LOWEST_EVENT_TIME: String = "lowestEventTimeInState"
 
   // Native operation latencies report as latency in microseconds
   // as SQLMetrics support millis. Convert the value to millis
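(Not part of the patch.) Instead of a secondary event-time index column family, the store now tracks
the lowest event time written in the current version (in microseconds, as TimestampType values are),
persists it under METADATA_KEY_LOWEST_EVENT_TIME in the commit's custom metadata, and uses it to skip
the full scan in evictOnWatermark. A hedged restatement of that guard as a standalone helper; the
function name is made up and the real logic lives inline above:

    // Long.MinValue (INVALID_LOWEST_EVENT_TIME_VALUE) marks "unknown, must scan".
    def canSkipEviction(lowestEventTimeMicros: Long, watermarkMs: Long): Boolean = {
      val known = lowestEventTimeMicros != Long.MinValue
      // Convert microseconds to milliseconds, keeping 1 ms of slack so truncation in the
      // conversion can never hide a row whose event time is actually at or below the watermark.
      known && ((lowestEventTimeMicros / 1000) - 1 > watermarkMs)
    }

    canSkipEviction(7200000000L, 3600000L)    // true: lowest event time ~2h, watermark at 1h
    canSkipEviction(Long.MinValue, 3600000L)  // false: no metadata yet, fall back to a full scan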
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
index 9a83f5c..12ac517 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
@@ -66,7 +66,7 @@ object StateStoreBenchmark extends SqlBasedBenchmark {
   private def runEvictBenchmark(): Unit = {
     runBenchmark("evict rows") {
       val numOfRows = Seq(10000) // Seq(1000, 10000, 100000)
-      val numOfTimestamps = Seq(10, 100, 1000)
+      val numOfTimestamps = Seq(100, 1000)
       val numOfEvictionRates = Seq(50, 25, 10, 5, 1, 0) // Seq(100, 75, 50, 25, 1, 0)
 
       numOfRows.foreach { numOfRow =>
@@ -116,7 +116,7 @@ object StateStoreBenchmark extends SqlBasedBenchmark {
               val rocksDBStore = rocksDBProvider.getStore(committedVersion)
 
               timer.startTiming()
-              evictAsFullScanAndRemove(rocksDBStore, maxTimestampToEvictInMillis)
+              evictAsFullScanAndRemove(rocksDBStore, maxTimestampToEvictInMillis, numOfRowsToEvict)
               timer.stopTiming()
 
               rocksDBStore.abort()
@@ -126,7 +126,7 @@ object StateStoreBenchmark extends SqlBasedBenchmark {
               val rocksDBStore = rocksDBProvider.getStore(committedVersion)
 
               timer.startTiming()
-              evictAsNewEvictApi(rocksDBStore, maxTimestampToEvictInMillis)
+              evictAsNewEvictApi(rocksDBStore, maxTimestampToEvictInMillis, numOfRowsToEvict)
               timer.stopTiming()
 
               rocksDBStore.abort()
@@ -487,20 +487,29 @@ object StateStoreBenchmark extends SqlBasedBenchmark {
 
   private def evictAsFullScanAndRemove(
       store: StateStore,
-      maxTimestampToEvict: Long): Unit = {
+      maxTimestampToEvict: Long,
+      expectedNumOfRows: Long): Unit = {
+    var removedRows: Long = 0
     store.iterator().foreach { r =>
-      if (r.key.getLong(1) < maxTimestampToEvict) {
+      if (r.key.getLong(1) / 1000 <= maxTimestampToEvict) {
         store.remove(r.key)
+        removedRows += 1
       }
     }
+    assert(removedRows == expectedNumOfRows,
+      s"expected: $expectedNumOfRows actual: $removedRows")
   }
 
   private def evictAsNewEvictApi(
       store: StateStore,
-      maxTimestampToEvict: Long): Unit = {
+      maxTimestampToEvict: Long,
+      expectedNumOfRows: Long): Unit = {
+    var removedRows: Long = 0
     store.evictOnWatermark(maxTimestampToEvict, pair => {
-      pair.key.getLong(1) < maxTimestampToEvict
-    }).foreach { _ => }
+      pair.key.getLong(1) / 1000 <= maxTimestampToEvict
+    }).foreach { _ => removedRows += 1 }
+    assert(removedRows == expectedNumOfRows,
+      s"expected: $expectedNumOfRows actual: $removedRows")
   }
 
   private def fullScanAndCompareTimestamp(
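(Not part of the patch.) Both benchmark paths now evict on the same condition and assert the number of
removed rows, which keeps the two strategies comparable. A small sketch of the shared predicate; index
1 is where this benchmark's key schema keeps the timestamp, stored in microseconds while the eviction
bound is in milliseconds, hence the division:

    import org.apache.spark.sql.execution.streaming.state.UnsafeRowPair

    def shouldEvict(pair: UnsafeRowPair, maxTimestampToEvictMs: Long): Boolean =
      pair.key.getLong(1) / 1000 <= maxTimestampToEvictMs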
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 1ee2748..31e49ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -355,14 +355,18 @@ class RocksDBSuite extends SparkFunSuite {
       RocksDBCheckpointMetadata(Seq.empty, 0L),
       """{"sstFiles":[],"numKeys":0}"""
     )
-    // shouldn't include the "logFiles" field in json when it's empty
+    // shouldn't include the "logFiles" & "customMetadata" fields in json when they're empty
     checkJsonRoundtrip(
       RocksDBCheckpointMetadata(sstFiles, 12345678901234L),
       """{"sstFiles":[{"localFileName":"00001.sst","dfsSstFileName":"00001-uuid.sst","sizeBytes":12345678901234}],"numKeys":12345678901234}"""
     )
+    // shouldn't include the "customMetadata" field in json when it's empty
     checkJsonRoundtrip(
-      RocksDBCheckpointMetadata(sstFiles, logFiles, 12345678901234L),
+      RocksDBCheckpointMetadata(sstFiles, logFiles, 12345678901234L, Map.empty),
       """{"sstFiles":[{"localFileName":"00001.sst","dfsSstFileName":"00001-uuid.sst","sizeBytes":12345678901234}],"logFiles":[{"localFileName":"00001.log","dfsLogFileName":"00001-uuid.log","sizeBytes":12345678901234}],"numKeys":12345678901234}""")
+
+    // FIXME: test customMetadata here
+
     // scalastyle:on line.size.limit
   }
 

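(Not part of the patch.) A possible shape for the missing customMetadata case noted in the FIXME above;
the expected JSON is only the assumed serialization of the map, not verified output:

    // scalastyle:off line.size.limit
    checkJsonRoundtrip(
      RocksDBCheckpointMetadata(sstFiles, logFiles, 12345678901234L,
        Map("lowestEventTimeInState" -> "1636099200000000")),
      """{"sstFiles":[{"localFileName":"00001.sst","dfsSstFileName":"00001-uuid.sst","sizeBytes":12345678901234}],"logFiles":[{"localFileName":"00001.log","dfsLogFileName":"00001-uuid.log","sizeBytes":12345678901234}],"numKeys":12345678901234,"customMetadata":{"lowestEventTimeInState":"1636099200000000"}}""")
    // scalastyle:on line.size.limit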
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org