You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ka...@apache.org on 2021/11/05 07:42:58 UTC
[spark] 01/02: WIP: benchmark test code done
This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch WIP-optimize-eviction-in-rocksdb-state-store
in repository https://gitbox.apache.org/repos/asf/spark.git
commit 21d4f96c6c7a56acce30d485a0747e1d62d7ad27
Author: Jungtaek Lim <ka...@gmail.com>
AuthorDate: Thu Sep 30 11:53:39 2021 +0900
WIP: benchmark test code done
---
.../streaming/FlatMapGroupsWithStateExec.scala | 7 +-
.../state/HDFSBackedStateStoreProvider.scala | 42 +-
.../sql/execution/streaming/state/RocksDB.scala | 158 ++++-
.../streaming/state/RocksDBStateEncoder.scala | 135 ++++-
.../state/RocksDBStateStoreProvider.scala | 100 +++-
.../sql/execution/streaming/state/StateStore.scala | 40 +-
.../execution/streaming/state/StateStoreRDD.scala | 8 +-
.../state/StreamingAggregationStateManager.scala | 23 +
.../state/SymmetricHashJoinStateManager.scala | 3 +-
.../sql/execution/streaming/state/package.scala | 12 +-
.../execution/streaming/statefulOperators.scala | 62 +-
.../sql/execution/streaming/streamingLimits.scala | 5 +-
.../execution/benchmark/StateStoreBenchmark.scala | 633 +++++++++++++++++++++
...ngSortWithSessionWindowStateIteratorSuite.scala | 7 +-
.../streaming/state/MemoryStateStore.scala | 14 +
.../state/RocksDBStateStoreIntegrationSuite.scala | 60 +-
.../streaming/state/RocksDBStateStoreSuite.scala | 6 +-
.../streaming/state/StateStoreRDDSuite.scala | 18 +-
.../streaming/state/StateStoreSuite.scala | 31 +-
.../StreamingSessionWindowStateManagerSuite.scala | 4 +-
.../apache/spark/sql/streaming/StreamSuite.scala | 4 +-
.../sql/streaming/StreamingAggregationSuite.scala | 34 +-
22 files changed, 1270 insertions(+), 136 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index a00a622..381aeb9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -224,21 +224,23 @@ case class FlatMapGroupsWithStateExec(
val stateStoreId = StateStoreId(
stateInfo.get.checkpointLocation, stateInfo.get.operatorId, partitionId)
val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId)
+ // FIXME: would setting prefixScan / evict help?
val store = StateStore.get(
storeProviderId,
groupingAttributes.toStructType,
stateManager.stateSchema,
- numColsPrefixKey = 0,
+ StatefulOperatorContext(),
stateInfo.get.storeVersion, storeConf, hadoopConfBroadcast.value.value)
val processor = new InputProcessor(store)
processDataWithPartition(childDataIterator, store, processor, Some(initStateIterator))
}
} else {
+ // FIXME: would setting prefixScan / evict help?
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateInfo,
groupingAttributes.toStructType,
stateManager.stateSchema,
- numColsPrefixKey = 0,
+ StatefulOperatorContext(),
session.sqlContext.sessionState,
Some(session.sqlContext.streams.stateStoreCoordinator)
) { case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
@@ -334,6 +336,7 @@ case class FlatMapGroupsWithStateExec(
throw new IllegalStateException(
s"Cannot filter timed out keys for $timeoutConf")
}
+ // FIXME: would setting prefixScan / evict help?
val timingOutPairs = stateManager.getAllState(store).filter { state =>
state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 75b7dae..96ba2a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -100,8 +100,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
/** Trait and classes representing the internal state of the store */
trait STATE
+
case object UPDATING extends STATE
+
case object COMMITTED extends STATE
+
case object ABORTED extends STATE
private val newVersion = version + 1
@@ -195,6 +198,22 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
override def toString(): String = {
s"HDFSStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]"
}
+
+ /**
+ * Lazily evicts rows from this store: while the returned iterator is consumed, every row
+ * for which `altPred` returns true is removed from the store and passed on to the caller.
+ * Rows the consumer never reaches are not examined and therefore not removed.
+ *
+ * @param watermarkMs watermark in milliseconds; not used by this implementation, which
+ * relies on `altPred` alone to select the rows to evict
+ * @param altPred predicate selecting the rows to evict
+ * @return iterator over the rows that get evicted while it is consumed
+ */
+ override def evictOnWatermark(
+ watermarkMs: Long,
+ altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+ // HDFSBackedStateStore doesn't index the event time column, so all rows are scanned.
+ // FIXME: should we do this for in-memory as well?
+ iterator().filter { pair =>
+ if (altPred.apply(pair)) {
+ remove(pair.key)
+ true
+ } else {
+ false
+ }
+ }
+ }
}
def getMetricsForProvider(): Map[String, Long] = synchronized {
@@ -219,7 +238,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
private def getLoadedMapForStore(version: Long): HDFSBackedStateStoreMap = synchronized {
require(version >= 0, "Version cannot be less than 0")
- val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+ val newMap = HDFSBackedStateStoreMap.create(keySchema, operatorContext.numColsPrefixKey)
if (version > 0) {
newMap.putAll(loadMap(version))
}
@@ -230,7 +249,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
stateStoreId: StateStoreId,
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int,
+ operatorContext: StatefulOperatorContext,
storeConf: StateStoreConf,
hadoopConf: Configuration): Unit = {
this.stateStoreId_ = stateStoreId
@@ -240,10 +259,11 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
this.hadoopConf = hadoopConf
this.numberOfVersionsToRetainInMemory = storeConf.maxVersionsToRetainInMemory
- require((keySchema.length == 0 && numColsPrefixKey == 0) ||
- (keySchema.length > numColsPrefixKey), "The number of columns in the key must be " +
- "greater than the number of columns for prefix key!")
- this.numColsPrefixKey = numColsPrefixKey
+ require((keySchema.length == 0 && operatorContext.numColsPrefixKey == 0) ||
+ (keySchema.length > operatorContext.numColsPrefixKey), "The number of columns in the key " +
+ "must be greater than the number of columns for prefix key!")
+
+ this.operatorContext = operatorContext
fm.mkdirs(baseDir)
}
@@ -283,7 +303,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
@volatile private var storeConf: StateStoreConf = _
@volatile private var hadoopConf: Configuration = _
@volatile private var numberOfVersionsToRetainInMemory: Int = _
- @volatile private var numColsPrefixKey: Int = 0
+ @volatile private var operatorContext: StatefulOperatorContext = _
// TODO: The validation should be moved to a higher level so that it works for all state store
// implementations
@@ -401,7 +421,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
if (lastAvailableVersion <= 0) {
// Use an empty map for versions 0 or less.
- lastAvailableMap = Some(HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey))
+ lastAvailableMap = Some(HDFSBackedStateStoreMap.create(keySchema,
+ operatorContext.numColsPrefixKey))
} else {
lastAvailableMap =
synchronized { Option(loadedMaps.get(lastAvailableVersion)) }
@@ -411,7 +432,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
// Load all the deltas from the version after the last available one up to the target version.
// The last available version is the one with a full snapshot, so it doesn't need deltas.
- val resultMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+ val resultMap = HDFSBackedStateStoreMap.create(keySchema,
+ operatorContext.numColsPrefixKey)
resultMap.putAll(lastAvailableMap.get)
for (deltaVersion <- lastAvailableVersion + 1 to version) {
updateFromDeltaFile(deltaVersion, resultMap)
@@ -554,7 +576,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
private def readSnapshotFile(version: Long): Option[HDFSBackedStateStoreMap] = {
val fileToRead = snapshotFile(version)
- val map = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+ val map = HDFSBackedStateStoreMap.create(keySchema, operatorContext.numColsPrefixKey)
var input: DataInputStream = null
try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 1ff8b41..eed7827 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.streaming.state
import java.io.File
+import java.util
import java.util.Locale
import javax.annotation.concurrent.GuardedBy
@@ -50,9 +51,12 @@ import org.apache.spark.util.{NextIterator, Utils}
* @param hadoopConf Hadoop configuration for talking to the remote file system
* @param loggingId Id that will be prepended in logs for isolating concurrent RocksDBs
*/
+// FIXME: optionally receiving column families
class RocksDB(
dfsRootDir: String,
val conf: RocksDBConf,
+ // TODO: change "default" to constant
+ columnFamilies: Seq[String] = Seq("default"),
localRootDir: File = Utils.createTempDir(),
hadoopConf: Configuration = new Configuration,
loggingId: String = "") extends Logging {
@@ -65,16 +69,10 @@ class RocksDB(
private val flushOptions = new FlushOptions().setWaitForFlush(true) // wait for flush to complete
private val writeBatch = new WriteBatchWithIndex(true) // overwrite multiple updates to a key
- private val bloomFilter = new BloomFilter()
- private val tableFormatConfig = new BlockBasedTableConfig()
- tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024)
- tableFormatConfig.setBlockCache(new LRUCache(conf.blockCacheSizeMB * 1024 * 1024))
- tableFormatConfig.setFilterPolicy(bloomFilter)
- tableFormatConfig.setFormatVersion(conf.formatVersion)
-
- private val dbOptions = new Options() // options to open the RocksDB
+ private val dbOptions: DBOptions = new DBOptions() // options to open the RocksDB
dbOptions.setCreateIfMissing(true)
- dbOptions.setTableFormatConfig(tableFormatConfig)
+ dbOptions.setCreateMissingColumnFamilies(true)
+
private val dbLogger = createLogger() // for forwarding RocksDB native logs to log4j
dbOptions.setStatistics(new Statistics())
private val nativeStats = dbOptions.statistics()
@@ -87,6 +85,8 @@ class RocksDB(
private val acquireLock = new Object
@volatile private var db: NativeRocksDB = _
+ @volatile private var columnFamilyHandles: util.Map[String, ColumnFamilyHandle] = _
+ @volatile private var defaultColumnFamilyHandle: ColumnFamilyHandle = _
@volatile private var loadedVersion = -1L // -1 = nothing valid is loaded
@volatile private var numKeysOnLoadedVersion = 0L
@volatile private var numKeysOnWritingVersion = 0L
@@ -96,7 +96,7 @@ class RocksDB(
@volatile private var acquiredThreadInfo: AcquiredThreadInfo = _
private val prefixScanReuseIter =
- new java.util.concurrent.ConcurrentHashMap[Long, RocksIterator]()
+ new java.util.concurrent.ConcurrentHashMap[(Long, Int), RocksIterator]()
/**
* Load the given version of data in a native RocksDB instance.
@@ -137,7 +137,28 @@ class RocksDB(
* @note This will return the last written value even if it was uncommitted.
*/
def get(key: Array[Byte]): Array[Byte] = {
- writeBatch.getFromBatchAndDB(db, readOptions, key)
+ get(defaultColumnFamilyHandle, key)
+ }
+
+ /**
+ * Get the value for the given key from the given column family.
+ * The column family name is resolved to its handle by findColumnFamilyHandle, which throws
+ * IllegalArgumentException when no handle is registered under `cf`.
+ */
+ def get(cf: String, key: Array[Byte]): Array[Byte] = {
+ get(findColumnFamilyHandle(cf), key)
+ }
+
+ private def get(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = {
+ writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
+ }
+
+ /**
+ * Add a merge of `value` into `key` to the write batch, targeting the default column family.
+ */
+ def merge(key: Array[Byte], value: Array[Byte]): Unit = {
+ merge(defaultColumnFamilyHandle, key, value)
+ }
+
+ /**
+ * Add a merge of `value` into `key` to the write batch, targeting the given column family.
+ * The column family name is resolved by findColumnFamilyHandle, which throws
+ * IllegalArgumentException when no handle is registered under `cf`.
+ */
+ def merge(cf: String, key: Array[Byte], value: Array[Byte]): Unit = {
+ merge(findColumnFamilyHandle(cf), key, value)
+ }
+
+ /**
+ * Record the merge in the write batch for the column family behind `cfHandle`.
+ * NOTE(review): the ColumnFamilyOptions built in openDB only set the table format config;
+ * no merge operator is configured there - confirm one is set before relying on merge.
+ */
+ private def merge(cfHandle: ColumnFamilyHandle, key: Array[Byte], value: Array[Byte]): Unit = {
+ writeBatch.merge(cfHandle, key, value)
+ }
/**
@@ -145,8 +166,20 @@ class RocksDB(
* @note This update is not committed to disk until commit() is called.
*/
def put(key: Array[Byte], value: Array[Byte]): Array[Byte] = {
- val oldValue = writeBatch.getFromBatchAndDB(db, readOptions, key)
- writeBatch.put(key, value)
+ put(defaultColumnFamilyHandle, key, value)
+ }
+
+ /**
+ * Put the given value for the given key into the given column family.
+ * The column family name is resolved by findColumnFamilyHandle, which throws
+ * IllegalArgumentException for an unknown name; the handle-based overload does the work and
+ * its result is returned as is.
+ */
+ def put(cf: String, key: Array[Byte], value: Array[Byte]): Array[Byte] = {
+ put(findColumnFamilyHandle(cf), key, value)
+ }
+
+ private def put(
+ cfHandle: ColumnFamilyHandle,
+ key: Array[Byte],
+ value: Array[Byte]): Array[Byte] = {
+ val oldValue = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
+ writeBatch.put(cfHandle, key, value)
if (oldValue == null) {
numKeysOnWritingVersion += 1
}
@@ -158,9 +191,18 @@ class RocksDB(
* @note This update is not committed to disk until commit() is called.
*/
def remove(key: Array[Byte]): Array[Byte] = {
- val value = writeBatch.getFromBatchAndDB(db, readOptions, key)
+ remove(defaultColumnFamilyHandle, key)
+ }
+
+ /**
+ * Remove the given key from the given column family.
+ * The column family name is resolved by findColumnFamilyHandle, which throws
+ * IllegalArgumentException for an unknown name; the handle-based overload does the work and
+ * its result is returned as is.
+ */
+ def remove(cf: String, key: Array[Byte]): Array[Byte] = {
+ remove(findColumnFamilyHandle(cf), key)
+ }
+
+ private def remove(cfHandle: ColumnFamilyHandle, key: Array[Byte]): Array[Byte] = {
+ val value = writeBatch.getFromBatchAndDB(db, cfHandle, readOptions, key)
if (value != null) {
- writeBatch.remove(key)
+ writeBatch.delete(cfHandle, key)
numKeysOnWritingVersion -= 1
}
value
@@ -169,8 +211,17 @@ class RocksDB(
/**
* Get an iterator of all committed and uncommitted key-value pairs.
*/
- def iterator(): Iterator[ByteArrayPair] = {
- val iter = writeBatch.newIteratorWithBase(db.newIterator())
+ def iterator(): NextIterator[ByteArrayPair] = {
+ iterator(defaultColumnFamilyHandle)
+ }
+
+ /**
+ * Get an iterator over the key-value pairs of the given column family.
+ * The column family name is resolved by findColumnFamilyHandle, which throws
+ * IllegalArgumentException when no handle is registered under `cf`.
+ */
+ def iterator(cf: String): NextIterator[ByteArrayPair] = {
+ iterator(findColumnFamilyHandle(cf))
+ }
+
+ private def iterator(cfHandle: ColumnFamilyHandle): NextIterator[ByteArrayPair] = {
+ val iter = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle))
logInfo(s"Getting iterator from version $loadedVersion")
iter.seekToFirst()
@@ -197,11 +248,20 @@ class RocksDB(
}
def prefixScan(prefix: Array[Byte]): Iterator[ByteArrayPair] = {
+ prefixScan(defaultColumnFamilyHandle, prefix)
+ }
+
+ def prefixScan(cf: String, prefix: Array[Byte]): Iterator[ByteArrayPair] = {
+ prefixScan(findColumnFamilyHandle(cf), prefix)
+ }
+
+ private def prefixScan(
+ cfHandle: ColumnFamilyHandle, prefix: Array[Byte]): Iterator[ByteArrayPair] = {
val threadId = Thread.currentThread().getId
- val iter = prefixScanReuseIter.computeIfAbsent(threadId, tid => {
- val it = writeBatch.newIteratorWithBase(db.newIterator())
+ val iter = prefixScanReuseIter.computeIfAbsent((threadId, cfHandle.getID), key => {
+ val it = writeBatch.newIteratorWithBase(cfHandle, db.newIterator(cfHandle))
logInfo(s"Getting iterator from version $loadedVersion for prefix scan on " +
- s"thread ID $tid")
+ s"thread ID ${key._1} and column family ID ${key._2}")
it
})
@@ -223,6 +283,14 @@ class RocksDB(
}
}
+ /**
+ * Resolve the handle registered for the given column family name.
+ *
+ * @param cf name of the column family
+ * @return the handle stored in columnFamilyHandles under `cf`
+ * @throws IllegalArgumentException if no handle is registered under `cf`
+ */
+ private def findColumnFamilyHandle(cf: String): ColumnFamilyHandle = {
+ // java.util.Map#get yields null for a missing name; Option(...) turns that into None.
+ Option(columnFamilyHandles.get(cf)).getOrElse {
+ throw new IllegalArgumentException(s"Handle for column family $cf is not found")
+ }
+ }
+
/**
* Commit all the updates made as a version to DFS. The steps it needs to do to commits are:
* - Write all the updates to the native RocksDB
@@ -242,11 +310,16 @@ class RocksDB(
val writeTimeMs = timeTakenMs { db.write(writeOptions, writeBatch) }
logInfo(s"Flushing updates for $newVersion")
- val flushTimeMs = timeTakenMs { db.flush(flushOptions) }
+ val flushTimeMs = timeTakenMs {
+ db.flush(flushOptions,
+ new util.ArrayList[ColumnFamilyHandle](columnFamilyHandles.values()))
+ }
val compactTimeMs = if (conf.compactOnCommit) {
logInfo("Compacting")
- timeTakenMs { db.compactRange() }
+ timeTakenMs {
+ columnFamilyHandles.values().forEach(cfHandle => db.compactRange(cfHandle))
+ }
} else 0
logInfo("Pausing background work")
@@ -279,6 +352,7 @@ class RocksDB(
loadedVersion
} catch {
case t: Throwable =>
+ logWarning(s"ERROR! exc: $t", t)
loadedVersion = -1 // invalidate loaded version
throw t
} finally {
@@ -422,12 +496,43 @@ class RocksDB(
+ /**
+ * Open the native RocksDB instance in workingDir with one column family per entry of
+ * `columnFamilies`, and remember the handle of each column family by name.
+ * Must only be called while no DB is open (asserted).
+ */
private def openDB(): Unit = {
assert(db == null)
- db = NativeRocksDB.open(dbOptions, workingDir.toString)
+
+ // One block cache is shared by every column family so that the total block cache size stays
+ // at conf.blockCacheSizeMB regardless of how many column families are opened.
+ val blockCache = new LRUCache(conf.blockCacheSizeMB * 1024 * 1024)
+
+ val columnFamilyDescriptors = new util.ArrayList[ColumnFamilyDescriptor]()
+ columnFamilies.foreach { cf =>
+ val bloomFilter = new BloomFilter()
+ val tableFormatConfig = new BlockBasedTableConfig()
+ tableFormatConfig.setBlockSize(conf.blockSizeKB * 1024)
+ tableFormatConfig.setBlockCache(blockCache)
+ tableFormatConfig.setFilterPolicy(bloomFilter)
+ tableFormatConfig.setFormatVersion(conf.formatVersion)
+
+ val columnFamilyOptions = new ColumnFamilyOptions()
+ columnFamilyOptions.setTableFormatConfig(tableFormatConfig)
+ // Encode the name explicitly as UTF-8 so it does not depend on the JVM default charset.
+ columnFamilyDescriptors.add(new ColumnFamilyDescriptor(
+ cf.getBytes(java.nio.charset.StandardCharsets.UTF_8), columnFamilyOptions))
+ }
+
+ val cfHandles = new util.ArrayList[ColumnFamilyHandle](columnFamilyDescriptors.size())
+ db = NativeRocksDB.open(dbOptions, workingDir.toString, columnFamilyDescriptors,
+ cfHandles)
+
+ // The handles are matched to the names by position, i.e. cfHandles is expected to follow
+ // the order of columnFamilyDescriptors (which follows the order of columnFamilies).
+ columnFamilyHandles = new util.HashMap[String, ColumnFamilyHandle]()
+ columnFamilies.indices.foreach { idx =>
+ columnFamilyHandles.put(columnFamilies(idx), cfHandles.get(idx))
+ }
+
+ // FIXME: constant
+ defaultColumnFamilyHandle = columnFamilyHandles.get("default")
+
logInfo(s"Opened DB with conf ${conf}")
}
private def closeDB(): Unit = {
if (db != null) {
+ columnFamilyHandles.entrySet().forEach(pair => db.destroyColumnFamilyHandle(pair.getValue))
+ columnFamilyHandles.clear()
+ columnFamilyHandles = null
+ defaultColumnFamilyHandle = null
+
db.close()
db = null
}
@@ -441,10 +546,17 @@ class RocksDB(
// Warn is mapped to info because RocksDB warn is too verbose
// (e.g. dumps non-warning stuff like stats)
val loggingFunc: ( => String) => Unit = infoLogLevel match {
+ /*
case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_)
case InfoLogLevel.WARN_LEVEL | InfoLogLevel.INFO_LEVEL => logInfo(_)
case InfoLogLevel.DEBUG_LEVEL => logDebug(_)
case _ => logTrace(_)
+ */
+ case InfoLogLevel.FATAL_LEVEL | InfoLogLevel.ERROR_LEVEL => logError(_)
+ case InfoLogLevel.WARN_LEVEL => logWarning(_)
+ case InfoLogLevel.INFO_LEVEL => logInfo(_)
+ case InfoLogLevel.DEBUG_LEVEL => logDebug(_)
+ case _ => logTrace(_)
}
loggingFunc(s"[NativeRocksDB-${infoLogLevel.getValue}] $logMsg")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index 81755e5..323826d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -17,9 +17,12 @@
package org.apache.spark.sql.execution.streaming.state
+import java.lang.{Long => JLong}
+import java.nio.ByteOrder
+
import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION}
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{StructField, StructType, TimestampType}
import org.apache.spark.unsafe.Platform
sealed trait RocksDBStateEncoder {
@@ -27,6 +30,11 @@ sealed trait RocksDBStateEncoder {
def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
def extractPrefixKey(key: UnsafeRow): UnsafeRow
+ def supportEventTimeIndex: Boolean
+ def extractEventTime(key: UnsafeRow): Long
+ def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte]
+ def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte])
+
def encodeKey(row: UnsafeRow): Array[Byte]
def encodeValue(row: UnsafeRow): Array[Byte]
@@ -39,11 +47,13 @@ object RocksDBStateEncoder {
def getEncoder(
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int): RocksDBStateEncoder = {
+ numColsPrefixKey: Int,
+ eventTimeColIdx: Array[Int]): RocksDBStateEncoder = {
if (numColsPrefixKey > 0) {
+ // FIXME: need to deal with prefix case as well
new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
} else {
- new NoPrefixKeyStateEncoder(keySchema, valueSchema)
+ new NoPrefixKeyStateEncoder(keySchema, valueSchema, eventTimeColIdx)
}
}
@@ -86,6 +96,39 @@ object RocksDBStateEncoder {
}
}
+/**
+ * Helpers that convert a signed Long to and from a Long whose bytes, once written to memory
+ * in the platform's native byte order, are laid out most-significant byte first with the sign
+ * bit flipped. Presumably this is meant for stores comparing keys byte by byte as unsigned
+ * values - confirm against the comparator in use.
+ */
+object BinarySortable {
+ private val NATIVE_BYTE_ORDER: ByteOrder = ByteOrder.nativeOrder()
+ private val LITTLE_ENDIAN: Boolean = NATIVE_BYTE_ORDER == ByteOrder.LITTLE_ENDIAN
+ private val SIGN_BIT_LONG: Long = (1L << 63)
+
+ /** Encodes `value`; the result is only meaningful when written in native byte order. */
+ def encodeToBinarySortableLong(value: Long): Long = {
+ // Flip the sign bit, so that negative values get a leading 0 bit and non-negative values
+ // get a leading 1 bit; the remaining bits already increase along with the value.
+ val encoded = value ^ SIGN_BIT_LONG
+
+ // The bytes must end up in memory in the same sequence as BIG_ENDIAN, as the binary form
+ // will be compared in sequential order (via offset). On a LITTLE_ENDIAN system the bytes are
+ // reversed here so that a native-order write produces that sequence.
+ if (LITTLE_ENDIAN) {
+ JLong.reverseBytes(encoded)
+ } else {
+ encoded
+ }
+ }
+
+ /** Reverts encodeToBinarySortableLong for a Long that was read in native byte order. */
+ def decodeBinarySortableLong(encoded: Long): Long = {
+ // The stored bytes follow BIG_ENDIAN. If the system is LITTLE_ENDIAN, reverse the bytes to
+ // get back the value in the system's own representation.
+ val decoded = if (LITTLE_ENDIAN) {
+ JLong.reverseBytes(encoded)
+ } else {
+ encoded
+ }
+
+ // Flip the sign bit back, as the encode function flipped it.
+ decoded ^ SIGN_BIT_LONG
+ }
+}
+
class PrefixKeyScanStateEncoder(
keySchema: StructType,
valueSchema: StructType,
@@ -185,6 +228,23 @@ class PrefixKeyScanStateEncoder(
}
override def supportPrefixKeyScan: Boolean = true
+
+ override def supportEventTimeIndex: Boolean = false
+
+ // FIXME: the event time index is not implemented for this encoder yet; this always throws.
+ def extractEventTime(key: UnsafeRow): Long = {
+ throw new IllegalStateException("This encoder doesn't support event time index!")
+ }
+
+ // FIXME: the event time index is not implemented for this encoder yet; this always throws.
+ def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = {
+ throw new IllegalStateException("This encoder doesn't support event time index!")
+ }
+
+ // FIXME: the event time index is not implemented for this encoder yet; this always throws.
+ def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = {
+ throw new IllegalStateException("This encoder doesn't support event time index!")
+ }
}
/**
@@ -197,8 +257,10 @@ class PrefixKeyScanStateEncoder(
* (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes,
* then the generated array byte will be N+1 bytes.
*/
-class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
- extends RocksDBStateEncoder {
+class NoPrefixKeyStateEncoder(
+ keySchema: StructType,
+ valueSchema: StructType,
+ eventTimeColIdx: Array[Int]) extends RocksDBStateEncoder {
import RocksDBStateEncoder._
@@ -207,6 +269,32 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
private val valueRow = new UnsafeRow(valueSchema.size)
private val rowTuple = new UnsafeRowPair()
+ validateColumnTypeOnEventTimeColumn()
+
+ /**
+ * Checks that `eventTimeColIdx`, when non-empty, is a valid path into `keySchema`: every
+ * ordinal except the last must select a struct column, and the last one must select a
+ * timestamp column. An empty `eventTimeColIdx` is accepted without any check.
+ *
+ * @throws IllegalStateException if an ordinal is out of range or a column on the path has an
+ * unexpected type
+ */
+ private def validateColumnTypeOnEventTimeColumn(): Unit = {
+ // FIXME: better error message
+ def failOnInvalidColumn(): Nothing = {
+ throw new IllegalStateException("event time column is not properly specified! " +
+ s"index: ${eventTimeColIdx.mkString("(", ", ", ")")} / key schema: $keySchema")
+ }
+
+ // Looks the ordinal up defensively, so that an out-of-range ordinal is reported with the
+ // same descriptive error instead of a bare index-out-of-bounds exception.
+ def dataTypeAt(schema: StructType, idx: Int) = {
+ schema.fields.lift(idx).map(_.dataType).getOrElse(failOnInvalidColumn())
+ }
+
+ if (eventTimeColIdx.nonEmpty) {
+ // Walk down the nested structs named by all ordinals but the last one.
+ val parentSchema = eventTimeColIdx.init.foldLeft(keySchema) { (schema, idx) =>
+ dataTypeAt(schema, idx) match {
+ case nested: StructType => nested
+ case _ => failOnInvalidColumn()
+ }
+ }
+
+ dataTypeAt(parentSchema, eventTimeColIdx.last) match {
+ case _: TimestampType =>
+ case _ => failOnInvalidColumn()
+ }
+ }
+ }
+
override def encodeKey(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
@@ -249,4 +337,41 @@ class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
throw new IllegalStateException("This encoder doesn't support prefix key!")
}
+
+ override def supportEventTimeIndex: Boolean = eventTimeColIdx.nonEmpty
+
+ /**
+ * Reads the event time column out of the given key by following `eventTimeColIdx`: all
+ * ordinals but the last descend into nested structs, the last one is read as a Long and
+ * divided by 1000.
+ * Must only be called when `eventTimeColIdx` is non-empty (see supportEventTimeIndex), as
+ * the last ordinal is taken unconditionally.
+ * NOTE(review): presumably the column holds microseconds and the result is milliseconds;
+ * the integer division drops the sub-millisecond part - confirm this matches how callers
+ * compare the value against the watermark.
+ */
+ override def extractEventTime(key: UnsafeRow): Long = {
+ var curRow: UnsafeRow = key
+ var curSchema: StructType = keySchema
+
+ eventTimeColIdx.dropRight(1).foreach { idx =>
+ // validation is done in initialization phase, hence the unchecked cast
+ curSchema = curSchema(idx).dataType.asInstanceOf[StructType]
+ curRow = curRow.getStruct(idx, curSchema.length)
+ }
+
+ curRow.getLong(eventTimeColIdx.last) / 1000
+ }
+
+ /**
+ * Builds the key of an event time index entry: 8 bytes holding the timestamp encoded by
+ * BinarySortable.encodeToBinarySortableLong, followed by the bytes of `encodedKey`.
+ *
+ * @param timestamp event time to place in front of the key
+ * @param encodedKey already encoded key of the row the index entry points to
+ * @return a new array of length 8 + encodedKey.length
+ */
+ override def encodeEventTimeIndexKey(timestamp: Long, encodedKey: Array[Byte]): Array[Byte] = {
+ val newKey = new Array[Byte](8 + encodedKey.length)
+
+ Platform.putLong(newKey, Platform.BYTE_ARRAY_OFFSET,
+ BinarySortable.encodeToBinarySortableLong(timestamp))
+ Platform.copyMemory(encodedKey, Platform.BYTE_ARRAY_OFFSET, newKey,
+ Platform.BYTE_ARRAY_OFFSET + 8, encodedKey.length)
+
+ newKey
+ }
+
+ /**
+ * Splits the key of an event time index entry into its timestamp and the encoded key of the
+ * row it points to; the inverse of encodeEventTimeIndexKey.
+ *
+ * @param eventTimeBytes index key: 8 bytes of encoded timestamp followed by the row key
+ * @return the decoded timestamp and a copy of the bytes following the first 8 bytes
+ * @throws IllegalArgumentException if `eventTimeBytes` is shorter than the 8-byte timestamp
+ */
+ override def decodeEventTimeIndexKey(eventTimeBytes: Array[Byte]): (Long, Array[Byte]) = {
+ // Reject inputs that cannot hold the timestamp prefix up front, instead of reading a Long
+ // from an array that is too short and failing later on a negative array size.
+ require(eventTimeBytes.length >= 8,
+ s"Event time index key must be at least 8 bytes, but got ${eventTimeBytes.length} bytes")
+
+ val encoded = Platform.getLong(eventTimeBytes, Platform.BYTE_ARRAY_OFFSET)
+ val timestamp = BinarySortable.decodeBinarySortableLong(encoded)
+
+ val keyLength = eventTimeBytes.length - 8
+ val key = new Array[Byte](keyLength)
+ Platform.copyMemory(eventTimeBytes, Platform.BYTE_ARRAY_OFFSET + 8,
+ key, Platform.BYTE_ARRAY_OFFSET, keyLength)
+
+ (timestamp, key)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index a2b33c2..1d66220 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -25,7 +25,8 @@ import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.util.{NextIterator, Utils}
private[sql] class RocksDBStateStoreProvider
extends StateStoreProvider with Logging with Closeable {
@@ -60,12 +61,23 @@ private[sql] class RocksDBStateStoreProvider
verify(state == UPDATING, "Cannot put after already committed or aborted")
verify(key != null, "Key cannot be null")
require(value != null, "Cannot put a null value")
- rocksDB.put(encoder.encodeKey(key), encoder.encodeValue(value))
+
+ val encodedKey = encoder.encodeKey(key)
+ val encodedValue = encoder.encodeValue(value)
+ rocksDB.put(encodedKey, encodedValue)
+
+ if (encoder.supportEventTimeIndex) {
+ val timestamp = encoder.extractEventTime(key)
+ val tsKey = encoder.encodeEventTimeIndexKey(timestamp, encodedKey)
+
+ rocksDB.put(CF_EVENT_TIME_INDEX, tsKey, Array.empty)
+ }
}
override def remove(key: UnsafeRow): Unit = {
verify(state == UPDATING, "Cannot remove after already committed or aborted")
verify(key != null, "Key cannot be null")
- rocksDB.remove(encoder.encodeKey(key))
+
+ val encodedKey = encoder.encodeKey(key)
+ rocksDB.remove(encodedKey)
+
+ // Mirror put(): drop the matching event time index entry as well. Otherwise the entry
+ // outlives its row and evictOnWatermark later fails to find the row it points to.
+ if (encoder.supportEventTimeIndex) {
+ val timestamp = encoder.extractEventTime(key)
+ val tsKey = encoder.encodeEventTimeIndexKey(timestamp, encodedKey)
+
+ rocksDB.remove(CF_EVENT_TIME_INDEX, tsKey)
+ }
}
@@ -161,13 +173,75 @@ private[sql] class RocksDBStateStoreProvider
/** Return the [[RocksDB]] instance in this store. This is exposed mainly for testing. */
def dbInstance(): RocksDB = rocksDB
+
+ /**
+ * Lazily evicts rows from this store; rows are removed while the returned iterator is
+ * consumed, and each removed row is handed to the caller.
+ *
+ * When the encoder supports the event time index, the entries of the CF_EVENT_TIME_INDEX
+ * column family are visited in iterator order. Iteration stops at the first entry whose
+ * decoded timestamp is greater than `watermarkMs`; every entry before that is removed from
+ * the index together with the row it points to. `altPred` is not consulted in this mode.
+ *
+ * Otherwise all rows of the default column family are scanned, and the rows for which
+ * `altPred` returns true are removed. `watermarkMs` is not consulted in this mode.
+ *
+ * @param watermarkMs watermark in milliseconds, used with the event time index
+ * @param altPred predicate selecting the rows to evict, used without the event time index
+ * @return iterator over the rows that get evicted while it is consumed
+ * @throws IllegalStateException while iterating, if an index entry points to a row that
+ * does not exist
+ */
+ override def evictOnWatermark(
+ watermarkMs: Long,
+ altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+
+ if (encoder.supportEventTimeIndex) {
+ val kv = new ByteArrayPair()
+
+ // FIXME: DEBUG
+ // logWarning(s"DEBUG: start iterating event time index, watermark $watermarkMs")
+
+ new NextIterator[UnsafeRowPair] {
+ private val iter = rocksDB.iterator(CF_EVENT_TIME_INDEX)
+ override protected def getNext(): UnsafeRowPair = {
+ if (iter.hasNext) {
+ val pair = iter.next()
+
+ val encodedTs = Platform.getLong(pair.key, Platform.BYTE_ARRAY_OFFSET)
+ val decodedTs = BinarySortable.decodeBinarySortableLong(encodedTs)
+
+ // FIXME: DEBUG
+ // logWarning(s"DEBUG: decoded TS: $decodedTs")
+
+ if (decodedTs > watermarkMs) {
+ finished = true
+ null
+ } else {
+ // FIXME: can we leverage deleteRange to bulk delete on index?
+ rocksDB.remove(CF_EVENT_TIME_INDEX, pair.key)
+ val (_, encodedKey) = encoder.decodeEventTimeIndexKey(pair.key)
+ val value = rocksDB.get(encodedKey)
+ if (value == null) {
+ throw new IllegalStateException("Event time index has been broken!")
+ }
+ kv.set(encodedKey, value)
+ val rowPair = encoder.decode(kv)
+ rocksDB.remove(encodedKey)
+ rowPair
+ }
+ } else {
+ finished = true
+ null
+ }
+ }
+
+ override protected def close(): Unit = {
+ iter.closeIfNeeded()
+ }
+ }
+ } else {
+ rocksDB.iterator().flatMap { kv =>
+ val rowPair = encoder.decode(kv)
+ if (altPred(rowPair)) {
+ rocksDB.remove(kv.key)
+ Some(rowPair)
+ } else {
+ None
+ }
+ }
+ }
+ }
}
override def init(
stateStoreId: StateStoreId,
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int,
+ operatorContext: StatefulOperatorContext,
storeConf: StateStoreConf,
hadoopConf: Configuration): Unit = {
this.stateStoreId_ = stateStoreId
@@ -176,11 +250,14 @@ private[sql] class RocksDBStateStoreProvider
this.storeConf = storeConf
this.hadoopConf = hadoopConf
- require((keySchema.length == 0 && numColsPrefixKey == 0) ||
- (keySchema.length > numColsPrefixKey), "The number of columns in the key must be " +
- "greater than the number of columns for prefix key!")
+ require((keySchema.length == 0 && operatorContext.numColsPrefixKey == 0) ||
+ (keySchema.length > operatorContext.numColsPrefixKey), "The number of columns in the key " +
+ "must be greater than the number of columns for prefix key!")
- this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema, numColsPrefixKey)
+ this.operatorContext = operatorContext
+
+ this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema,
+ operatorContext.numColsPrefixKey, operatorContext.eventTimeColIdx)
rocksDB // lazy initialization
}
@@ -212,6 +289,7 @@ private[sql] class RocksDBStateStoreProvider
@volatile private var valueSchema: StructType = _
@volatile private var storeConf: StateStoreConf = _
@volatile private var hadoopConf: Configuration = _
+ @volatile private var operatorContext: StatefulOperatorContext = _
private[sql] lazy val rocksDB = {
val dfsRootDir = stateStoreId.storeCheckpointLocation().toString
@@ -219,7 +297,10 @@ private[sql] class RocksDBStateStoreProvider
s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})"
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr)
- new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr)
+ new RocksDB(dfsRootDir, RocksDBConf(storeConf),
+ columnFamilies = Seq("default", RocksDBStateStoreProvider.CF_EVENT_TIME_INDEX),
+ localRootDir = localRootDir,
+ hadoopConf = hadoopConf, loggingId = storeIdStr)
}
@volatile private var encoder: RocksDBStateEncoder = _
@@ -234,6 +315,9 @@ object RocksDBStateStoreProvider {
val STATE_ENCODING_NUM_VERSION_BYTES = 1
val STATE_ENCODING_VERSION: Byte = 0
+ // reserved column families
+ val CF_EVENT_TIME_INDEX: String = "__event_time_idx"
+
// Native operation latencies report as latency in microseconds
// as SQLMetrics support millis. Convert the value to millis
val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 5020638..cee9ad1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -130,6 +130,11 @@ trait StateStore extends ReadStateStore {
*/
override def iterator(): Iterator[UnsafeRowPair]
+ /** Evicts rows for the watermark (altPred: row-level predicate); returns evicted pairs. */
+ def evictOnWatermark(
+ watermarkMs: Long,
+ altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair]
+
/** Current metrics of the state store */
def metrics: StateStoreMetrics
@@ -229,6 +234,19 @@ class InvalidUnsafeRowException
"checkpoint or use the legacy Spark version to process the streaming state.", null)
/**
+ * Operator-specific settings handed to a [[StateStoreProvider]] when it is initialized.
+ *
+ * @param numColsPrefixKey The number of leftmost columns to be used as prefix key.
+ * A value not greater than 0 means the operator doesn't activate prefix
+ * key, and the operator should not call prefixScan method in StateStore.
+ * @param eventTimeColIdx Ordinal path to the event time column of the row. Only works when the
+ * column is in the key. An array because the column can be nested in a struct type.
+ */
+case class StatefulOperatorContext(
+ numColsPrefixKey: Int = 0,
+ eventTimeColIdx: Array[Int] = Array.empty)
+
+/**
* Trait representing a provider that provide [[StateStore]] instances representing
* versions of state data.
*
@@ -255,9 +273,7 @@ trait StateStoreProvider {
* @param stateStoreId Id of the versioned StateStores that this provider will generate
* @param keySchema Schema of keys to be stored
* @param valueSchema Schema of value to be stored
- * @param numColsPrefixKey The number of leftmost columns to be used as prefix key.
- * A value not greater than 0 means the operator doesn't activate prefix
- * key, and the operator should not call prefixScan method in StateStore.
+ * @param operatorContext Operator-specific settings (prefix key columns, event time column)
* @param storeConfs Configurations used by the StateStores
* @param hadoopConf Hadoop configuration that could be used by StateStore to save state data
*/
@@ -265,7 +281,7 @@ trait StateStoreProvider {
stateStoreId: StateStoreId,
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int,
+ operatorContext: StatefulOperatorContext,
storeConfs: StateStoreConf,
hadoopConf: Configuration): Unit
@@ -318,11 +334,11 @@ object StateStoreProvider {
providerId: StateStoreProviderId,
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int,
+ operatorContext: StatefulOperatorContext,
storeConf: StateStoreConf,
hadoopConf: Configuration): StateStoreProvider = {
val provider = create(storeConf.providerClass)
- provider.init(providerId.storeId, keySchema, valueSchema, numColsPrefixKey,
+ provider.init(providerId.storeId, keySchema, valueSchema, operatorContext,
storeConf, hadoopConf)
provider
}
@@ -471,13 +487,13 @@ object StateStore extends Logging {
storeProviderId: StateStoreProviderId,
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int,
+ operatorContext: StatefulOperatorContext,
version: Long,
storeConf: StateStoreConf,
hadoopConf: Configuration): ReadStateStore = {
require(version >= 0)
val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
- numColsPrefixKey, storeConf, hadoopConf)
+ operatorContext, storeConf, hadoopConf)
storeProvider.getReadStore(version)
}
@@ -486,13 +502,13 @@ object StateStore extends Logging {
storeProviderId: StateStoreProviderId,
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int,
+ operatorContext: StatefulOperatorContext,
version: Long,
storeConf: StateStoreConf,
hadoopConf: Configuration): StateStore = {
require(version >= 0)
val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
- numColsPrefixKey, storeConf, hadoopConf)
+ operatorContext, storeConf, hadoopConf)
storeProvider.getStore(version)
}
@@ -500,7 +516,7 @@ object StateStore extends Logging {
storeProviderId: StateStoreProviderId,
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int,
+ operatorContext: StatefulOperatorContext,
storeConf: StateStoreConf,
hadoopConf: Configuration): StateStoreProvider = {
loadedProviders.synchronized {
@@ -527,7 +543,7 @@ object StateStore extends Logging {
val provider = loadedProviders.getOrElseUpdate(
storeProviderId,
StateStoreProvider.createAndInit(
- storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeConf, hadoopConf)
+ storeProviderId, keySchema, valueSchema, operatorContext, storeConf, hadoopConf)
)
val otherProviderIds = loadedProviders.keys.filter(_ != storeProviderId).toSeq
val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, otherProviderIds)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index fbe83ad..33d83b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -74,7 +74,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
storeVersion: Long,
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int,
+ operatorContext: StatefulOperatorContext,
sessionState: SessionState,
@transient private val storeCoordinator: Option[StateStoreCoordinatorRef],
extraOptions: Map[String, String] = Map.empty)
@@ -87,7 +87,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
val storeProviderId = getStateProviderId(partition)
val store = StateStore.getReadOnly(
- storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeVersion,
+ storeProviderId, keySchema, valueSchema, operatorContext, storeVersion,
storeConf, hadoopConfBroadcast.value.value)
val inputIter = dataRDD.iterator(partition, ctxt)
storeReadFunction(store, inputIter)
@@ -108,7 +108,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
storeVersion: Long,
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int,
+ operatorContext: StatefulOperatorContext,
sessionState: SessionState,
@transient private val storeCoordinator: Option[StateStoreCoordinatorRef],
extraOptions: Map[String, String] = Map.empty)
@@ -121,7 +121,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
val storeProviderId = getStateProviderId(partition)
val store = StateStore.get(
- storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeVersion,
+ storeProviderId, keySchema, valueSchema, operatorContext, storeVersion,
storeConf, hadoopConfBroadcast.value.value)
val inputIter = dataRDD.iterator(partition, ctxt)
storeUpdateFunction(store, inputIter)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
index 36138f1..c6b63cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
@@ -51,6 +51,12 @@ sealed trait StreamingAggregationStateManager extends Serializable {
/** Remove a single non-null key from the target state store. */
def remove(store: StateStore, key: UnsafeRow): Unit
+ /** Evict rows from the target state store for the watermark; returns the evicted pairs. */
+ def evictOnWatermark(
+ store: StateStore,
+ watermarkMs: Long,
+ altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair]
+
/** Return an iterator containing all the key-value pairs in target state store. */
def iterator(store: ReadStateStore): Iterator[UnsafeRowPair]
@@ -128,6 +134,13 @@ class StreamingAggregationStateManagerImplV1(
override def values(store: ReadStateStore): Iterator[UnsafeRow] = {
store.iterator().map(_.value)
}
+
+ override def evictOnWatermark(
+ store: StateStore,
+ watermarkMs: Long,
+ altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+ store.evictOnWatermark(watermarkMs, altPred) // plain delegation; pairs returned as-is
+ }
}
/**
@@ -186,6 +199,16 @@ class StreamingAggregationStateManagerImplV2(
store.iterator().map(rowPair => new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair)))
}
+
+ override def evictOnWatermark(
+ store: StateStore,
+ watermarkMs: Long,
+ altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+ store.evictOnWatermark(watermarkMs, altPred).map { rowPair =>
+ new UnsafeRowPair(rowPair.key, restoreOriginalRow(rowPair)) // same as iterator()
+ }
+ }
+
override def values(store: ReadStateStore): Iterator[UnsafeRow] = {
store.iterator().map(rowPair => restoreOriginalRow(rowPair))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
index f301d23..ce56845 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -363,8 +363,9 @@ class SymmetricHashJoinStateManager(
protected def getStateStore(keySchema: StructType, valueSchema: StructType): StateStore = {
val storeProviderId = StateStoreProviderId(
stateInfo.get, partitionId, getStateStoreName(joinSide, stateStoreType))
+ // FIXME: would setting prefixScan / evict help?
val store = StateStore.get(
- storeProviderId, keySchema, valueSchema, numColsPrefixKey = 0,
+ storeProviderId, keySchema, valueSchema, StatefulOperatorContext(),
stateInfo.get.storeVersion, storeConf, hadoopConf)
logInfo(s"Loaded store ${store.id}")
store
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 01ff72b..7cd25eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -35,14 +35,14 @@ package object state {
stateInfo: StatefulOperatorStateInfo,
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int)(
+ operatorContext: StatefulOperatorContext)(
storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
mapPartitionsWithStateStore(
stateInfo,
keySchema,
valueSchema,
- numColsPrefixKey,
+ operatorContext,
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator))(
storeUpdateFunction)
@@ -53,7 +53,7 @@ package object state {
stateInfo: StatefulOperatorStateInfo,
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int,
+ operatorContext: StatefulOperatorContext,
sessionState: SessionState,
storeCoordinator: Option[StateStoreCoordinatorRef],
extraOptions: Map[String, String] = Map.empty)(
@@ -77,7 +77,7 @@ package object state {
stateInfo.storeVersion,
keySchema,
valueSchema,
- numColsPrefixKey,
+ operatorContext,
sessionState,
storeCoordinator,
extraOptions)
@@ -88,7 +88,7 @@ package object state {
stateInfo: StatefulOperatorStateInfo,
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int,
+ operatorContext: StatefulOperatorContext,
sessionState: SessionState,
storeCoordinator: Option[StateStoreCoordinatorRef],
extraOptions: Map[String, String] = Map.empty)(
@@ -112,7 +112,7 @@ package object state {
stateInfo.storeVersion,
keySchema,
valueSchema,
- numColsPrefixKey,
+ operatorContext,
sessionState,
storeCoordinator,
extraOptions)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 3431823..923d94e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -237,11 +237,10 @@ trait WatermarkSupport extends SparkPlan {
protected def removeKeysOlderThanWatermark(store: StateStore): Unit = {
if (watermarkPredicateForKeys.nonEmpty) {
val numRemovedStateRows = longMetric("numRemovedStateRows")
- store.iterator().foreach { rowPair =>
- if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
- store.remove(rowPair.key)
- numRemovedStateRows += 1
- }
+ store.evictOnWatermark(eventTimeWatermark.get, pair => {
+ watermarkPredicateForKeys.get.eval(pair.key)
+ }).foreach { _ =>
+ numRemovedStateRows += 1
}
}
}
@@ -251,11 +250,10 @@ trait WatermarkSupport extends SparkPlan {
store: StateStore): Unit = {
if (watermarkPredicateForKeys.nonEmpty) {
val numRemovedStateRows = longMetric("numRemovedStateRows")
- storeManager.keys(store).foreach { keyRow =>
- if (watermarkPredicateForKeys.get.eval(keyRow)) {
- storeManager.remove(store, keyRow)
- numRemovedStateRows += 1
- }
+ storeManager.evictOnWatermark(store,
+ eventTimeWatermark.get, pair => watermarkPredicateForKeys.get.eval(pair.key)
+ ).foreach { _ =>
+ numRemovedStateRows += 1
}
}
}
@@ -307,7 +305,8 @@ case class StateStoreRestoreExec(
getStateInfo,
keyExpressions.toStructType,
stateManager.getStateValueSchema,
- numColsPrefixKey = 0,
+ // FIXME: set event time column here!
+ StatefulOperatorContext(),
session.sessionState,
Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
val hasInput = iter.hasNext
@@ -365,11 +364,26 @@ case class StateStoreSaveExec(
assert(outputMode.nonEmpty,
"Incorrect planning in IncrementalExecution, outputMode has not been set")
+ val eventTimeIdx = keyExpressions.indexWhere(_.metadata.contains(EventTimeWatermark.delayKey))
+ val eventTimeColIdx = if (eventTimeIdx >= 0) {
+ keyExpressions.toStructType(eventTimeIdx).dataType match {
+ // FIXME: for now, we only consider window operation here, as we do the same in
+ // WatermarkSupport.watermarkExpression
+ case StructType(_) => Array[Int](eventTimeIdx, 1)
+ case TimestampType => Array[Int](eventTimeIdx)
+ case _ => throw new IllegalStateException(
+ "The type of event time column should be timestamp")
+ }
+ } else {
+ Array.empty[Int]
+ }
+
child.execute().mapPartitionsWithStateStore(
getStateInfo,
keyExpressions.toStructType,
stateManager.getStateValueSchema,
- numColsPrefixKey = 0,
+ // event time column (if any) is resolved from the key expressions above
+ StatefulOperatorContext(eventTimeColIdx = eventTimeColIdx),
session.sessionState,
Some(session.streams.stateStoreCoordinator)) { (store, iter) =>
val numOutputRows = longMetric("numOutputRows")
@@ -414,18 +428,16 @@ case class StateStoreSaveExec(
}
val removalStartTimeNs = System.nanoTime
- val rangeIter = stateManager.iterator(store)
+ val evictedIter = stateManager.evictOnWatermark(store,
+ eventTimeWatermark.get, pair => watermarkPredicateForKeys.get.eval(pair.key))
new NextIterator[InternalRow] {
override protected def getNext(): InternalRow = {
var removedValueRow: InternalRow = null
- while(rangeIter.hasNext && removedValueRow == null) {
- val rowPair = rangeIter.next()
- if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
- stateManager.remove(store, rowPair.key)
- numRemovedStateRows += 1
- removedValueRow = rowPair.value
- }
+ while(evictedIter.hasNext && removedValueRow == null) {
+ val rowPair = evictedIter.next()
+ numRemovedStateRows += 1
+ removedValueRow = rowPair.value
}
if (removedValueRow == null) {
finished = true
@@ -541,7 +553,8 @@ case class SessionWindowStateStoreRestoreExec(
getStateInfo,
stateManager.getStateKeySchema,
stateManager.getStateValueSchema,
- numColsPrefixKey = stateManager.getNumColsForPrefixKey,
+ // FIXME: set event time column here!
+ StatefulOperatorContext(stateManager.getNumColsForPrefixKey),
session.sessionState,
Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
@@ -618,7 +631,8 @@ case class SessionWindowStateStoreSaveExec(
getStateInfo,
stateManager.getStateKeySchema,
stateManager.getStateValueSchema,
- numColsPrefixKey = stateManager.getNumColsForPrefixKey,
+ // FIXME: set event time column!
+ StatefulOperatorContext(numColsPrefixKey = stateManager.getNumColsForPrefixKey),
session.sessionState,
Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
@@ -652,6 +666,7 @@ case class SessionWindowStateStoreSaveExec(
val removalStartTimeNs = System.nanoTime
new NextIterator[InternalRow] {
+ // FIXME: can we optimize this case as well?
private val removedIter = stateManager.removeByValueCondition(
store, watermarkPredicateForData.get.eval)
@@ -751,7 +766,8 @@ case class StreamingDeduplicateExec(
getStateInfo,
keyExpressions.toStructType,
child.output.toStructType,
- numColsPrefixKey = 0,
+ // FIXME: set event time column!
+ StatefulOperatorContext(),
session.sessionState,
Some(session.streams.stateStoreCoordinator),
// We won't check value row in state store since the value StreamingDeduplicateExec.EMPTY_ROW
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
index 8bba9b8..ba15f50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, SortOrder, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, Partitioning}
import org.apache.spark.sql.execution.{LimitExec, SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.streaming.state.StateStoreOps
+import org.apache.spark.sql.execution.streaming.state.{StatefulOperatorContext, StateStoreOps}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{LongType, NullType, StructField, StructType}
import org.apache.spark.util.{CompletionIterator, NextIterator}
@@ -52,7 +52,8 @@ case class StreamingGlobalLimitExec(
getStateInfo,
keySchema,
valueSchema,
- numColsPrefixKey = 0,
+ // FIXME: set event time column!
+ StatefulOperatorContext(),
session.sessionState,
Some(session.streams.stateStoreCoordinator)) { (store, iter) =>
val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
new file mode 100644
index 0000000..9a83f5c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/StateStoreBenchmark.scala
@@ -0,0 +1,633 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import java.{util => jutil}
+
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.rocksdb.RocksDBException
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StatefulOperatorContext, StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType, TimestampType}
+import org.apache.spark.util.Utils
+
+/**
+ * Synthetic benchmark for State Store operations.
+ * To run this benchmark:
+ * {{{
+ * 1. without sbt:
+ * bin/spark-submit --class <this class>
+ * --jars <spark core test jar>,<spark catalyst test jar> <sql core test jar>
+ * 2. build/sbt "sql/test:runMain <this class>"
+ * 3. generate result:
+ * SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ * Results will be written to "benchmarks/StateStoreBenchmark-results.txt".
+ * }}}
+ */
+object StateStoreBenchmark extends SqlBasedBenchmark {
+
+ private val numOfRows: Seq[Int] = Seq(10000, 50000, 100000) // Seq(10000, 100000, 1000000)
+
+ // 200%, 100%, 50%, 25%, 10%, 5%, 1%, no update
+ // rate is relative to the number of rows in prev. batch
+ private val updateRates: Seq[Int] = Seq(25, 10, 5) // Seq(200, 100, 50, 25, 10, 5, 1, 0)
+
+ // 100%, 50%, 25%, 10%, 5%, 1%, no evict
+ // rate is relative to the number of rows in prev. batch
+ private val evictRates: Seq[Int] = Seq(100, 50, 25, 10, 5, 1, 0)
+
+ private val keySchema = StructType(
+ Seq(StructField("key1", IntegerType, true), StructField("key2", TimestampType, true)))
+ private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+
+ private val keyProjection = UnsafeProjection.create(keySchema)
+ private val valueProjection = UnsafeProjection.create(valueSchema)
+
+ private def runEvictBenchmark(): Unit = { // times two eviction paths on one committed version
+ runBenchmark("evict rows") {
+ val numOfRows = Seq(10000) // Seq(1000, 10000, 100000)
+ val numOfTimestamps = Seq(10, 100, 1000)
+ val numOfEvictionRates = Seq(50, 25, 10, 5, 1, 0) // Seq(100, 75, 50, 25, 1, 0)
+
+ numOfRows.foreach { numOfRow =>
+ numOfTimestamps.foreach { numOfTimestamp =>
+ val timestampsInMicros = (0L until numOfTimestamp).map(ts => ts * 1000L).toList
+
+ val testData = constructRandomizedTestData(numOfRow, timestampsInMicros, 0)
+
+ val rocksDBProvider = newRocksDBStateProviderWithEventTimeIdx()
+ val rocksDBStore = rocksDBProvider.getStore(0)
+ updateRows(rocksDBStore, testData)
+
+ val committedVersion = try { // version every timer case below reloads
+ rocksDBStore.commit()
+ } catch {
+ case exc: RocksDBException =>
+ // scalastyle:off println
+ System.out.println(s"Exception in RocksDB happen! ${exc.getMessage} / " +
+ s"status: ${exc.getStatus.getState} / ${exc.getStatus.getCodeString}" )
+ exc.printStackTrace()
+ throw exc
+ // scalastyle:on println
+ }
+
+ numOfEvictionRates.foreach { numOfEvictionRate =>
+ val numOfRowsToEvict = numOfRow * numOfEvictionRate / 100 // used in the title only
+ // scalastyle:off println
+ System.out.println(s"numOfRowsToEvict: $numOfRowsToEvict / " +
+ s"timestampsInMicros: $timestampsInMicros / " +
+ s"numOfEvictionRate: $numOfEvictionRate / " +
+ s"numOfTimestamp: $numOfTimestamp / " +
+ s"take: ${numOfTimestamp * numOfEvictionRate / 100}")
+
+ // scalastyle:on println
+ val maxTimestampToEvictInMillis = timestampsInMicros
+ .take(numOfTimestamp * numOfEvictionRate / 100) // int division: can be 0 (10 ts at 5%)
+ .lastOption.map(_ / 1000).getOrElse(-1L) // micros -> millis; -1 if none taken
+
+ val benchmark = new Benchmark(s"evicting $numOfRowsToEvict rows " +
+ s"(max timestamp to evict in millis: $maxTimestampToEvictInMillis) " +
+ s"from $numOfRow rows with $numOfTimestamp timestamps " +
+ s"(${numOfRow / numOfTimestamp} rows" +
+ s" for the same timestamp)",
+ numOfRow, minNumIters = 1000, output = output)
+
+ benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+ val rocksDBStore = rocksDBProvider.getStore(committedVersion)
+
+ timer.startTiming() // only the eviction call is timed
+ evictAsFullScanAndRemove(rocksDBStore, maxTimestampToEvictInMillis)
+ timer.stopTiming()
+
+ rocksDBStore.abort() // presumably discards this run's changes - confirm
+ }
+
+ benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer =>
+ val rocksDBStore = rocksDBProvider.getStore(committedVersion)
+
+ timer.startTiming() // only the eviction call is timed
+ evictAsNewEvictApi(rocksDBStore, maxTimestampToEvictInMillis)
+ timer.stopTiming()
+
+ rocksDBStore.abort() // presumably discards this run's changes - confirm
+ }
+
+ benchmark.run()
+ }
+
+ rocksDBProvider.close() // NOTE(review): not in a finally; skipped if a case throws
+ }
+ }
+ }
+ }
+
+ private def runPutBenchmark(): Unit = { // times updateRows on two RocksDB providers
+ runBenchmark("put rows") {
+ val numOfRows = Seq(10000) // Seq(1000, 10000, 100000)
+ val numOfTimestamps = Seq(100, 1000, 10000) // Seq(1, 10, 100, 1000, 10000)
+ numOfRows.foreach { numOfRow =>
+ numOfTimestamps.foreach { numOfTimestamp =>
+ val timestamps = (0L until numOfTimestamp).map(ts => ts * 1000L).toList
+
+ val testData = constructRandomizedTestData(numOfRow, timestamps, 0)
+
+ val rocksDBProvider = newRocksDBStateProvider()
+ val rocksDBWithIdxProvider = newRocksDBStateProviderWithEventTimeIdx()
+
+ val benchmark = new Benchmark(s"putting $numOfRow rows, with $numOfTimestamp " +
+ s"timestamps (${numOfRow / numOfTimestamp} rows for the same timestamp)",
+ numOfRow, minNumIters = 1000, output = output)
+
+ benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+ val rocksDBStore = rocksDBProvider.getStore(0)
+
+ timer.startTiming() // only updateRows is timed
+ updateRows(rocksDBStore, testData)
+ timer.stopTiming()
+
+ rocksDBStore.abort() // presumably discards the puts - confirm
+ }
+
+ benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer =>
+ val rocksDBWithIdxStore = rocksDBWithIdxProvider.getStore(0)
+
+ timer.startTiming() // only updateRows is timed
+ updateRows(rocksDBWithIdxStore, testData)
+ timer.stopTiming()
+
+ rocksDBWithIdxStore.abort() // presumably discards the puts - confirm
+ }
+
+ benchmark.run()
+
+ rocksDBProvider.close() // NOTE(review): not in a finally; skipped if a case throws
+ rocksDBWithIdxProvider.close()
+ }
+ }
+ }
+ }
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ // runPutBenchmark()
+ runEvictBenchmark()
+
+ /*
+ val testData = constructRandomizedTestData(numOfRows.max)
+
+ skip("scanning and comparing") {
+ numOfRows.foreach { numOfRow =>
+ val curData = testData.take(numOfRow)
+
+ val inMemoryProvider = newHDFSBackedStateStoreProvider()
+ val inMemoryStore = inMemoryProvider.getStore(0)
+
+ val rocksDBProvider = newRocksDBStateProvider()
+ val rocksDBStore = rocksDBProvider.getStore(0)
+
+ updateRows(inMemoryStore, curData)
+ updateRows(rocksDBStore, curData)
+
+ val newVersionForInMemory = inMemoryStore.commit()
+ val newVersionForRocksDB = rocksDBStore.commit()
+
+ val benchmark = new Benchmark(s"scanning and comparing $numOfRow rows",
+ numOfRow, minNumIters = 1000, output = output)
+
+ benchmark.addTimerCase("HDFSBackedStateStoreProvider") { timer =>
+ val inMemoryStore2 = inMemoryProvider.getStore(newVersionForInMemory)
+
+ timer.startTiming()
+ // NOTE: the latency would be quite similar regardless of the rate of eviction
+ // as we don't remove the actual row, so I simply picked 10 %
+ fullScanAndCompareTimestamp(inMemoryStore2, (numOfRow * 0.1).toInt)
+ timer.stopTiming()
+
+ inMemoryStore2.abort()
+ }
+
+ benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+ val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB)
+
+ timer.startTiming()
+ // NOTE: the latency would be quite similar regardless of the rate of eviction
+ // as we don't remove the actual row, so I simply picked 10 %
+ fullScanAndCompareTimestamp(rocksDBStore2, (numOfRow * 0.1).toInt)
+ timer.stopTiming()
+
+ rocksDBStore2.abort()
+ }
+
+ benchmark.run()
+
+ inMemoryProvider.close()
+ rocksDBProvider.close()
+ }
+ }
+
+ // runBenchmark("simulate full operations on eviction") {
+ skip("simulate full operations on eviction") {
+ numOfRows.foreach { numOfRow =>
+ val curData = testData.take(numOfRow)
+
+ val inMemoryProvider = newHDFSBackedStateStoreProvider()
+ val inMemoryStore = inMemoryProvider.getStore(0)
+
+ val rocksDBProvider = newRocksDBStateProvider()
+ val rocksDBStore = rocksDBProvider.getStore(0)
+
+ val indexForInMemoryStore = new jutil.concurrent.ConcurrentSkipListMap[
+ Int, jutil.List[UnsafeRow]]()
+ val indexForRocksDBStore = new jutil.concurrent.ConcurrentSkipListMap[
+ Int, jutil.List[UnsafeRow]]()
+
+ updateRowsWithSortedMapIndex(inMemoryStore, indexForInMemoryStore, curData)
+ updateRowsWithSortedMapIndex(rocksDBStore, indexForRocksDBStore, curData)
+
+ assert(indexForInMemoryStore.size() == numOfRow)
+ assert(indexForRocksDBStore.size() == numOfRow)
+
+ val newVersionForInMemory = inMemoryStore.commit()
+ val newVersionForRocksDB = rocksDBStore.commit()
+
+ val rowsToUpdate = constructRandomizedTestData(numOfRow / 100 * updateRates.max,
+ minIdx = numOfRow + 1)
+
+ updateRates.foreach { updateRate =>
+ val numRowsUpdate = numOfRow / 100 * updateRate
+ val curRowsToUpdate = rowsToUpdate.take(numRowsUpdate)
+
+ evictRates.foreach { evictRate =>
+ val maxIdxToEvict = numOfRow / 100 * evictRate
+
+ val benchmark = new Benchmark(s"simulating evict on $numOfRow rows, update " +
+ s"$numRowsUpdate rows ($updateRate %), evict $maxIdxToEvict rows ($evictRate %)",
+ numOfRow, minNumIters = 100, output = output)
+
+ benchmark.addTimerCase("HDFSBackedStateStoreProvider") { timer =>
+ val inMemoryStore2 = inMemoryProvider.getStore(newVersionForInMemory)
+
+ timer.startTiming()
+ updateRows(inMemoryStore2, curRowsToUpdate)
+ evictAsFullScanAndRemove(inMemoryStore2, maxIdxToEvict)
+ timer.stopTiming()
+
+ inMemoryStore2.abort()
+ }
+
+ benchmark.addTimerCase("HDFSBackedStateStoreProvider - sorted map index") { timer =>
+
+ val inMemoryStore2 = inMemoryProvider.getStore(newVersionForInMemory)
+
+ val curIndex = new jutil.concurrent.ConcurrentSkipListMap[Int,
+ jutil.List[UnsafeRow]]()
+ curIndex.putAll(indexForInMemoryStore)
+
+ assert(curIndex.size() == numOfRow)
+
+ timer.startTiming()
+ updateRowsWithSortedMapIndex(inMemoryStore2, curIndex, curRowsToUpdate)
+
+ assert(curIndex.size() == numOfRow + curRowsToUpdate.size)
+
+ evictAsScanSortedMapIndexAndRemove(inMemoryStore2, curIndex, maxIdxToEvict)
+ timer.stopTiming()
+
+ assert(curIndex.size() == numOfRow + curRowsToUpdate.size - maxIdxToEvict)
+
+ curIndex.clear()
+
+ inMemoryStore2.abort()
+ }
+
+ benchmark.run()
+
+ val benchmark2 = new Benchmark(s"simulating evict on $numOfRow rows, update " +
+ s"$numRowsUpdate rows ($updateRate %), evict $maxIdxToEvict rows ($evictRate %)",
+ numOfRow, minNumIters = 100, output = output)
+
+ benchmark2.addTimerCase("RocksDBStateStoreProvider") { timer =>
+ val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB)
+
+ timer.startTiming()
+ updateRows(rocksDBStore2, curRowsToUpdate)
+ evictAsFullScanAndRemove(rocksDBStore2, maxIdxToEvict)
+ timer.stopTiming()
+
+ rocksDBStore2.abort()
+ }
+
+ benchmark2.addTimerCase("RocksDBStateStoreProvider - sorted map index") { timer =>
+
+ val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB)
+
+ val curIndex = new jutil.concurrent.ConcurrentSkipListMap[Int,
+ jutil.List[UnsafeRow]]()
+ curIndex.putAll(indexForRocksDBStore)
+
+ assert(curIndex.size() == numOfRow)
+
+ timer.startTiming()
+ updateRowsWithSortedMapIndex(rocksDBStore2, curIndex, curRowsToUpdate)
+
+ assert(curIndex.size() == numOfRow + curRowsToUpdate.size)
+
+ evictAsScanSortedMapIndexAndRemove(rocksDBStore2, curIndex, maxIdxToEvict)
+ timer.stopTiming()
+
+ assert(curIndex.size() == numOfRow + curRowsToUpdate.size - maxIdxToEvict)
+
+ curIndex.clear()
+
+ rocksDBStore2.abort()
+ }
+
+ benchmark2.run()
+ }
+ }
+
+ inMemoryProvider.close()
+ rocksDBProvider.close()
+ }
+ }
+
+ // runBenchmark("simulate full operations on eviction") {
+ skip("simulate full operations on eviction - scannable index") {
+ numOfRows.foreach { numOfRow =>
+ val curData = testData.take(numOfRow)
+
+ val rocksDBProvider = newRocksDBStateProvider()
+ val rocksDBStore = rocksDBProvider.getStore(0)
+
+ val rocksDBWithIdxProvider = newRocksDBStateProviderWithEventTimeIdx()
+ val rocksDBWithIdxStore = rocksDBWithIdxProvider.getStore(0)
+
+ updateRows(rocksDBStore, curData)
+ updateRows(rocksDBWithIdxStore, curData)
+
+ val newVersionForRocksDB = rocksDBStore.commit()
+ val newVersionForRocksDBWithIdx = rocksDBWithIdxStore.commit()
+
+ val rowsToUpdate = constructRandomizedTestData(numOfRow / 100 * updateRates.max,
+ minIdx = numOfRow + 1)
+
+ updateRates.foreach { updateRate =>
+ val numRowsUpdate = numOfRow / 100 * updateRate
+ val curRowsToUpdate = rowsToUpdate.take(numRowsUpdate)
+
+ evictRates.foreach { evictRate =>
+ val maxIdxToEvict = numOfRow / 100 * evictRate
+
+ val benchmark = new Benchmark(s"simulating evict on $numOfRow rows, update " +
+ s"$numRowsUpdate rows ($updateRate %), evict $maxIdxToEvict rows ($evictRate %)",
+ numOfRow, minNumIters = 100, output = output)
+
+ benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+ val rocksDBStore2 = rocksDBProvider.getStore(newVersionForRocksDB)
+
+ timer.startTiming()
+ updateRows(rocksDBStore2, curRowsToUpdate)
+ evictAsFullScanAndRemove(rocksDBStore2, maxIdxToEvict)
+ // evictAsNewEvictApi(rocksDBStore2, maxIdxToEvict)
+ timer.stopTiming()
+
+ rocksDBStore2.abort()
+ }
+
+ benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer =>
+ val rocksDBWithIdxStore2 = rocksDBWithIdxProvider.getStore(
+ newVersionForRocksDBWithIdx)
+
+ timer.startTiming()
+ updateRows(rocksDBWithIdxStore2, curRowsToUpdate)
+ evictAsNewEvictApi(rocksDBWithIdxStore2, maxIdxToEvict)
+ timer.stopTiming()
+
+ rocksDBWithIdxStore2.abort()
+ }
+
+ benchmark.run()
+ }
+ }
+
+ rocksDBProvider.close()
+ rocksDBWithIdxProvider.close()
+ }
+ }
+
+ runBenchmark("put rows") {
+ numOfRows.foreach { numOfRow =>
+ val curData = testData.take(numOfRow)
+
+ val rocksDBProvider = newRocksDBStateProvider()
+ val rocksDBWithIdxProvider = newRocksDBStateProviderWithEventTimeIdx()
+
+ val benchmark = new Benchmark(s"putting $numOfRow rows",
+ numOfRow, minNumIters = 1000, output = output)
+
+ benchmark.addTimerCase("RocksDBStateStoreProvider") { timer =>
+ val rocksDBStore = rocksDBProvider.getStore(0)
+
+ timer.startTiming()
+ updateRows(rocksDBStore, curData)
+ timer.stopTiming()
+
+ rocksDBStore.abort()
+ }
+
+ benchmark.addTimerCase("RocksDBStateStoreProvider with event time idx") { timer =>
+ val rocksDBWithIdxStore = rocksDBWithIdxProvider.getStore(0)
+
+ timer.startTiming()
+ updateRows(rocksDBWithIdxStore, curData)
+ timer.stopTiming()
+
+ rocksDBWithIdxStore.abort()
+ }
+
+ benchmark.run()
+
+ rocksDBProvider.close()
+ rocksDBWithIdxProvider.close()
+ }
+ }
+ */
+ }
+
+  /**
+   * Records that the benchmark named `benchmarkName` was skipped instead of running it.
+   *
+   * Intended as a drop-in replacement for a `runBenchmark(name) { ... }` call: `func` is a
+   * by-name parameter that is never referenced here, so the benchmark body is not evaluated.
+   *
+   * @param benchmarkName name reported in the skip notice
+   * @param func benchmark body; accepted only to keep the call shape, never executed
+   */
+  final def skip(benchmarkName: String)(func: => Any): Unit = {
+    // Terminate the notice with a newline so that consecutive notices (and whatever is written
+    // next) do not end up concatenated on one line of the output.
+    output.foreach(_.write(s"$benchmarkName is skipped\n".getBytes))
+  }
+
+  /**
+   * Writes every (key, value) pair of `rows` into `store` via `put`, in the order in which
+   * `rows` yields them.
+   */
+  private def updateRows(
+      store: StateStore,
+      rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = {
+    val pending = rows.iterator
+    while (pending.hasNext) {
+      val (rowKey, rowValue) = pending.next()
+      store.put(rowKey, rowValue)
+    }
+  }
+
+  /**
+   * Evicts by scanning the whole store: every pair returned by `store.iterator()` is inspected
+   * and its key is removed when the long stored in key column 1 is strictly smaller than
+   * `maxTimestampToEvict`.
+   */
+  private def evictAsFullScanAndRemove(
+      store: StateStore,
+      maxTimestampToEvict: Long): Unit = {
+    val pairs = store.iterator()
+    while (pairs.hasNext) {
+      val pair = pairs.next()
+      if (pair.key.getLong(1) < maxTimestampToEvict) {
+        store.remove(pair.key)
+      }
+    }
+  }
+
+  /**
+   * Evicts through `StateStore.evictOnWatermark`, passing `maxTimestampToEvict` as the
+   * watermark together with a predicate that matches pairs whose key column 1 (read as a long)
+   * is strictly smaller than it. The returned iterator is drained and its elements discarded.
+   */
+  private def evictAsNewEvictApi(
+      store: StateStore,
+      maxTimestampToEvict: Long): Unit = {
+    val evictedPairs = store.evictOnWatermark(
+      maxTimestampToEvict,
+      candidate => candidate.key.getLong(1) < maxTimestampToEvict)
+
+    // Walk the iterator to its end; the elements themselves are not needed.
+    while (evictedPairs.hasNext) {
+      evictedPairs.next()
+    }
+  }
+
+  /**
+   * Scans every pair of `store` and compares key column 1 against `maxIdxToEvict` without
+   * removing anything, i.e. the pure scan-and-compare part of a full-scan eviction.
+   *
+   * Key column 1 is written with `setLong` by the test data generators of this file and read
+   * with `getLong` by the eviction helpers, so it is read as a long here as well; reading it
+   * with `getInt` would only look at 32 bits of the stored value.
+   *
+   * @param store store to scan
+   * @param maxIdxToEvict exclusive upper bound that key column 1 is compared against
+   */
+  private def fullScanAndCompareTimestamp(
+      store: StateStore,
+      maxIdxToEvict: Int): Unit = {
+    var matched: Long = 0
+    store.iterator().foreach { r =>
+      if (r.key.getLong(1) < maxIdxToEvict) {
+        // simply to keep the "if statement" from being a no-op
+        matched += 1
+      }
+    }
+  }
+
+  /**
+   * Writes `rows` into `store` while maintaining `index`, a sorted map from the int read from
+   * key column 1 to the list of keys carrying that value.
+   *
+   * For each row the key is appended to the list registered under its index value (the list is
+   * created on first use) and the pair is then put into the store.
+   *
+   * NOTE(review): the test data generators in this file write key column 1 with `setLong`,
+   * while the index is keyed by the result of `getInt(1)` - confirm this is intended.
+   *
+   * @param store store receiving the rows
+   * @param index sorted index updated in place
+   * @param rows (key, value) pairs to write
+   */
+  private def updateRowsWithSortedMapIndex(
+      store: StateStore,
+      index: jutil.SortedMap[Int, jutil.List[UnsafeRow]],
+      rows: Seq[(UnsafeRow, UnsafeRow)]): Unit = {
+    rows.foreach { case (key, value) =>
+      val idx = key.getInt(1)
+
+      // A single lookup-or-create call replaces the former containsKey/get/put sequence.
+      // NOTE(review): on a concurrent map this makes installing the list atomic, but the
+      // `add` on the ArrayList itself is still unsynchronized.
+      index.computeIfAbsent(idx, (_: Int) => new jutil.ArrayList[UnsafeRow]()).add(key)
+
+      store.put(key, value)
+    }
+  }
+
+  /**
+   * Evicts with the help of the sorted index: walks the `headMap` view of `index` holding the
+   * entries whose index value is at most `maxIdxToEvict`, removes every key listed in those
+   * entries from `store`, empties each list and finally drops the entry from the index through
+   * the view's iterator.
+   */
+  private def evictAsScanSortedMapIndexAndRemove(
+      store: StateStore,
+      index: jutil.SortedMap[Int, jutil.List[UnsafeRow]],
+      maxIdxToEvict: Int): Unit = {
+    // headMap's bound is exclusive, hence the + 1 to include maxIdxToEvict itself.
+    val expiredEntries = index.headMap(maxIdxToEvict + 1).entrySet().iterator()
+    while (expiredEntries.hasNext) {
+      val rowKeys = expiredEntries.next().getValue
+
+      val rowKeyIter = rowKeys.iterator()
+      while (rowKeyIter.hasNext) {
+        store.remove(rowKeyIter.next())
+      }
+
+      rowKeys.clear()
+      expiredEntries.remove()
+    }
+  }
+
+  // FIXME: should the size of key / value be variables?
+  /**
+   * Builds `numRows` (key, value) pairs for row numbers `minIdx + 1` to `minIdx + numRows`.
+   * Each key holds the constant int 1 in column 0 and `rowNumber * 1000L` as a long in
+   * column 1; each value holds the row number as an int. Both rows are run through
+   * `keyProjection` / `valueProjection` and copied.
+   */
+  private def constructTestData(numRows: Int, minIdx: Int = 0): Seq[(UnsafeRow, UnsafeRow)] = {
+    for (offset <- 1 to numRows) yield {
+      val rowNumber = minIdx + offset
+
+      val key = new GenericInternalRow(2)
+      key.setInt(0, 1)
+      key.setLong(1, rowNumber * 1000L) // microseconds
+
+      val value = new GenericInternalRow(1)
+      value.setInt(0, rowNumber)
+
+      val projectedKey = keyProjection(key).copy()
+      val projectedValue = valueProjection(value).copy()
+      (projectedKey, projectedValue)
+    }
+  }
+
+  /**
+   * Builds `numRows` (key, value) pairs whose keys are not generated in sorted order, since
+   * ordered keys may affect the performance of RocksDB.
+   *
+   * Column 0 of each key is a random non-negative int; column 1 is taken from `timestamps`,
+   * cycling through the list by `(minIdx + idx) % timestamps.length`. The value holds
+   * `minIdx + idx` as an int.
+   *
+   * @param numRows number of pairs to build; must be a multiple of `timestamps.length` and at
+   *                least as large
+   * @param timestamps non-empty pool of values for key column 1
+   * @param minIdx offset added to the 1-based row number
+   */
+  private def constructRandomizedTestData(
+      numRows: Int,
+      timestamps: List[Long],
+      minIdx: Int = 0): Seq[(UnsafeRow, UnsafeRow)] = {
+    // Fail with a clear message instead of the division by zero the modulo below would raise.
+    assert(timestamps.nonEmpty, "timestamps must not be empty")
+
+    // List.length and List.apply are linear; snapshot into an indexed collection once instead
+    // of paying that cost for every generated row.
+    val timestampPool = timestamps.toVector
+    val poolSize = timestampPool.length
+
+    assert(numRows >= poolSize)
+    assert(numRows % poolSize == 0)
+
+    (1 to numRows).map { idx =>
+      val keyRow = new GenericInternalRow(2)
+      keyRow.setInt(0, Random.nextInt(Int.MaxValue))
+      keyRow.setLong(1, timestampPool((minIdx + idx) % poolSize)) // microseconds
+      val valueRow = new GenericInternalRow(1)
+      valueRow.setInt(0, minIdx + idx)
+
+      val keyUnsafeRow = keyProjection(keyRow).copy()
+      val valueUnsafeRow = valueProjection(valueRow).copy()
+
+      (keyUnsafeRow, valueUnsafeRow)
+    }
+  }
+
+  /** Creates an initialized HDFS-backed provider with a default operator context. */
+  private def newHDFSBackedStateStoreProvider(): StateStoreProvider = {
+    newInitializedProvider(new HDFSBackedStateStoreProvider(), StatefulOperatorContext())
+  }
+
+  /** Creates an initialized RocksDB provider with a default operator context. */
+  private def newRocksDBStateProvider(): StateStoreProvider = {
+    newInitializedProvider(new RocksDBStateStoreProvider(), StatefulOperatorContext())
+  }
+
+  /**
+   * Creates an initialized RocksDB provider whose operator context declares key column 1 as
+   * the event time column.
+   */
+  private def newRocksDBStateProviderWithEventTimeIdx(): StateStoreProvider = {
+    newInitializedProvider(
+      new RocksDBStateStoreProvider(),
+      StatefulOperatorContext(eventTimeColIdx = Array(1)))
+  }
+
+  /**
+   * Shared setup of the provider factories above: initializes `provider` against a fresh
+   * temporary directory with a random operator id and partition 0, using `keySchema` /
+   * `valueSchema` and a SQLConf whose state store compression codec is set to "zstd".
+   *
+   * @param provider freshly constructed, not yet initialized provider
+   * @param operatorContext operator context handed to `init`
+   * @return the same `provider` instance, initialized
+   */
+  private def newInitializedProvider(
+      provider: StateStoreProvider,
+      operatorContext: StatefulOperatorContext): StateStoreProvider = {
+    val storeId = StateStoreId(newDir(), Random.nextInt(), 0)
+    val sqlConf = new SQLConf()
+    sqlConf.setConfString("spark.sql.streaming.stateStore.compression.codec", "zstd")
+    val storeConf = new StateStoreConf(sqlConf)
+    provider.init(
+      storeId, keySchema, valueSchema, operatorContext,
+      storeConf, new Configuration)
+    provider
+  }
+
+  /** Creates a new temporary directory via `Utils.createTempDir` and returns it as a string. */
+  private def newDir(): String = {
+    val tempDir = Utils.createTempDir()
+    tempDir.toString
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
index 81f1a3f..4f7c040 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala
@@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
-import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId, StreamingSessionWindowStateManager}
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StatefulOperatorContext, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId, StreamingSessionWindowStateManager}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
@@ -217,9 +217,12 @@ class MergingSortWithSessionWindowStateIteratorSuite extends StreamTest with Bef
stateFormatVersion)
val storeProviderId = StateStoreProviderId(stateInfo, 0, StateStoreId.DEFAULT_STORE_NAME)
+
+ // FIXME: event time column?
val store = StateStore.get(
storeProviderId, manager.getStateKeySchema, manager.getStateValueSchema,
- manager.getNumColsForPrefixKey, stateInfo.storeVersion, storeConf, new Configuration)
+ StatefulOperatorContext(numColsPrefixKey = manager.getNumColsForPrefixKey),
+ stateInfo.storeVersion, storeConf, new Configuration)
try {
f(manager, store)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
index e52ccd0..4b53c0a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
@@ -50,4 +50,18 @@ class MemoryStateStore extends StateStore() {
override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
throw new UnsupportedOperationException("Doesn't support prefix scan!")
}
+
+  /**
+   * Evicts the pairs accepted by `altPred` and returns them.
+   *
+   * The result is a filtered view over `iterator()`: a pair is removed from this store at the
+   * moment the filter accepts it, so removals only happen as far as the returned iterator is
+   * consumed. `watermarkMs` is not consulted by this implementation; only `altPred` decides
+   * which pairs are evicted.
+   */
+  override def evictOnWatermark(
+      watermarkMs: Long,
+      altPred: UnsafeRowPair => Boolean): Iterator[UnsafeRowPair] = {
+    iterator().filter { candidate =>
+      val shouldEvict = altPred(candidate)
+      if (shouldEvict) {
+        remove(candidate.key)
+      }
+      shouldEvict
+    }
+  }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
index 2d741d3..7374303 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters
import org.scalatest.time.{Minute, Span}
import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper}
-import org.apache.spark.sql.functions.count
+import org.apache.spark.sql.functions.{count, timestamp_seconds, window}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
@@ -52,6 +52,64 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest {
}
}
+ // Runs a windowed count (5-second windows, 10-second watermark delay on "eventTime") with
+ // RocksDBStateStoreProvider configured as the state store provider, and checks that a window
+ // is only emitted once the watermark has passed it and that late input emits nothing.
+ // No output mode is passed to testStream, so its default applies (Append, going by the test
+ // name - TODO confirm).
+ // NOTE(review): the state-row / dropped-row assertions below are commented out, so only the
+ // emitted answers are verified.
+ test("append mode") {
+ val inputData = MemoryStream[Int]
+ val conf = Map(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName)
+
+ val windowedAggregation = inputData.toDF()
+ .withColumn("eventTime", timestamp_seconds($"value"))
+ .withWatermark("eventTime", "10 seconds")
+ .groupBy(window($"eventTime", "5 seconds") as 'window)
+ .agg(count("*") as 'count)
+ .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
+
+ testStream(windowedAggregation)(
+ StartStream(additionalConfs = conf),
+ AddData(inputData, 10, 11, 12, 13, 14, 15),
+ CheckNewAnswer(),
+ AddData(inputData, 25), // Advance watermark to 15 seconds
+ CheckNewAnswer((10, 5)),
+ // assertNumStateRows(2),
+ // assertNumRowsDroppedByWatermark(0),
+ AddData(inputData, 10), // Should not emit anything as data less than watermark
+ CheckNewAnswer()
+ // assertNumStateRows(2),
+ // assertNumRowsDroppedByWatermark(1)
+ )
+ }
+
+ // Same windowed count as the "append mode" test (5-second windows, 10-second watermark
+ // delay, RocksDBStateStoreProvider), but run with OutputMode.Update: every batch is expected
+ // to emit the windows whose count changed, and input older than the watermark is expected to
+ // be ignored.
+ // NOTE(review): the state-row / dropped-row assertions below are commented out, so only the
+ // emitted answers are verified.
+ test("update mode") {
+ val inputData = MemoryStream[Int]
+ val conf = Map(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName)
+
+ val windowedAggregation = inputData.toDF()
+ .withColumn("eventTime", timestamp_seconds($"value"))
+ .withWatermark("eventTime", "10 seconds")
+ .groupBy(window($"eventTime", "5 seconds") as 'window)
+ .agg(count("*") as 'count)
+ .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
+
+ testStream(windowedAggregation, OutputMode.Update)(
+ StartStream(additionalConfs = conf),
+ AddData(inputData, 10, 11, 12, 13, 14, 15),
+ CheckNewAnswer((10, 5), (15, 1)),
+ AddData(inputData, 25), // Advance watermark to 15 seconds
+ CheckNewAnswer((25, 1)),
+ // assertNumStateRows(2),
+ // assertNumRowsDroppedByWatermark(0),
+ AddData(inputData, 10, 25), // Ignore 10 as it's less than watermark
+ CheckNewAnswer((25, 2)),
+ // assertNumStateRows(2),
+ // assertNumRowsDroppedByWatermark(1),
+ AddData(inputData, 10), // Should not emit anything as data less than watermark
+ CheckNewAnswer()
+ // assertNumStateRows(2),
+ // assertNumRowsDroppedByWatermark(1)
+ )
+ }
+
test("SPARK-36236: query progress contains only the expected RocksDB store custom metrics") {
// fails if any new custom metrics are added to remind the author of API changes
import testImplicits._
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index c93d0f0..c06a9ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -87,8 +87,9 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
queryRunId = UUID.randomUUID, operatorId = 0, storeVersion = 0, numPartitions = 5)
// Create state store in a task and get the RocksDBConf from the instantiated RocksDB instance
+ // FIXME: event time column?
val rocksDBConfInTask: RocksDBConf = testRDD.mapPartitionsWithStateStore[RocksDBConf](
- spark.sqlContext, testStateInfo, testSchema, testSchema, 0) {
+ spark.sqlContext, testStateInfo, testSchema, testSchema, StatefulOperatorContext()) {
(store: StateStore, _: Iterator[String]) =>
// Use reflection to get RocksDB instance
val dbInstanceMethod =
@@ -144,7 +145,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
numColsPrefixKey: Int): RocksDBStateStoreProvider = {
val provider = new RocksDBStateStoreProvider()
provider.init(
- storeId, keySchema, valueSchema, numColsPrefixKey = numColsPrefixKey,
+ storeId, keySchema, valueSchema,
+ StatefulOperatorContext(numColsPrefixKey = numColsPrefixKey),
new StateStoreConf, new Configuration)
provider
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 6bb8ebe..0e9b19d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -60,13 +60,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString
val rdd1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
.mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 0),
- keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+ keySchema, valueSchema, StatefulOperatorContext())(increment)
assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1))
// Generate next version of stores
val rdd2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("c", 0)))
.mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 1),
- keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+ keySchema, valueSchema, StatefulOperatorContext())(increment)
assert(rdd2.collect().toSet === Set(("a", 0) -> 3, ("b", 0) -> 1, ("c", 0) -> 1))
// Make sure the previous RDD still has the same data.
@@ -84,7 +84,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
implicit val sqlContext = spark.sqlContext
makeRDD(spark.sparkContext, Seq(("a", 0))).mapPartitionsWithStateStore(
sqlContext, operatorStateInfo(path, version = storeVersion),
- keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+ keySchema, valueSchema, StatefulOperatorContext())(increment)
}
// Generate RDDs and state store data
@@ -134,19 +134,19 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
val rddOfGets1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0)))
.mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 0),
- keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfGets)
+ keySchema, valueSchema, StatefulOperatorContext())(iteratorOfGets)
assert(rddOfGets1.collect().toSet ===
Set(("a", 0) -> None, ("b", 0) -> None, ("c", 0) -> None))
val rddOfPuts = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
.mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 0),
- keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfPuts)
+ keySchema, valueSchema, StatefulOperatorContext())(iteratorOfPuts)
assert(rddOfPuts.collect().toSet ===
Set(("a", 0) -> 1, ("a", 0) -> 2, ("b", 0) -> 1))
val rddOfGets2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0)))
.mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 1),
- keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfGets)
+ keySchema, valueSchema, StatefulOperatorContext())(iteratorOfGets)
assert(rddOfGets2.collect().toSet ===
Set(("a", 0) -> Some(2), ("b", 0) -> Some(1), ("c", 0) -> None))
}
@@ -172,7 +172,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
val rdd = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
.mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, queryRunId = queryRunId),
- keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+ keySchema, valueSchema, StatefulOperatorContext())(increment)
require(rdd.partitions.length === 2)
assert(
@@ -200,13 +200,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
val opId = 0
val rdd1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
.mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 0),
- keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+ keySchema, valueSchema, StatefulOperatorContext())(increment)
assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1))
// Generate next version of stores
val rdd2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("c", 0)))
.mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 1),
- keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+ keySchema, valueSchema, StatefulOperatorContext())(increment)
assert(rdd2.collect().toSet === Set(("a", 0) -> 3, ("b", 0) -> 1, ("c", 0) -> 1))
// Make sure the previous RDD still has the same data.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 601b62b..d89c6fed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -270,8 +270,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
def generateStoreVersions(): Unit = {
for (i <- 1 to 20) {
- val store = StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0,
- latestStoreVersion, storeConf, hadoopConf)
+ val store = StateStore.get(storeProviderId1, keySchema, valueSchema,
+ StatefulOperatorContext(), latestStoreVersion, storeConf, hadoopConf)
put(store, "a", 0, i)
store.commit()
latestStoreVersion += 1
@@ -324,7 +324,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
}
// Reload the store and verify
- StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0,
+ StateStore.get(storeProviderId1, keySchema, valueSchema, StatefulOperatorContext(),
latestStoreVersion, storeConf, hadoopConf)
assert(StateStore.isLoaded(storeProviderId1))
@@ -336,7 +336,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
}
// Reload the store and verify
- StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0,
+ StateStore.get(storeProviderId1, keySchema, valueSchema, StatefulOperatorContext(),
latestStoreVersion, storeConf, hadoopConf)
assert(StateStore.isLoaded(storeProviderId1))
@@ -344,7 +344,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
// then this executor should unload inactive instances immediately.
coordinatorRef
.reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
- StateStore.get(storeProviderId2, keySchema, valueSchema, numColsPrefixKey = 0,
+ StateStore.get(storeProviderId2, keySchema, valueSchema, StatefulOperatorContext(),
0, storeConf, hadoopConf)
assert(!StateStore.isLoaded(storeProviderId1))
assert(StateStore.isLoaded(storeProviderId2))
@@ -453,7 +453,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
// Getting the store should not create temp file
val store0 = shouldNotCreateTempFile {
StateStore.get(
- storeId, keySchema, valueSchema, numColsPrefixKey = 0,
+ storeId, keySchema, valueSchema, StatefulOperatorContext(),
version = 0, storeConf, hadoopConf)
}
@@ -470,7 +470,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
// Remove should create a temp file
val store1 = shouldNotCreateTempFile {
StateStore.get(
- storeId, keySchema, valueSchema, numColsPrefixKey = 0,
+ storeId, keySchema, valueSchema, StatefulOperatorContext(),
version = 1, storeConf, hadoopConf)
}
remove(store1, _._1 == "a")
@@ -485,7 +485,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
// Commit without any updates should create a delta file
val store2 = shouldNotCreateTempFile {
StateStore.get(
- storeId, keySchema, valueSchema, numColsPrefixKey = 0,
+ storeId, keySchema, valueSchema, StatefulOperatorContext(),
version = 2, storeConf, hadoopConf)
}
store2.commit()
@@ -720,11 +720,12 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
hadoopConf: Configuration = new Configuration): HDFSBackedStateStoreProvider = {
val sqlConf = getDefaultSQLConf(minDeltasForSnapshot, numOfVersToRetainInMemory)
val provider = new HDFSBackedStateStoreProvider()
+ // FIXME: event time column?
provider.init(
StateStoreId(dir, opId, partition),
keySchema,
valueSchema,
- numColsPrefixKey = numColsPrefixKey,
+ StatefulOperatorContext(numColsPrefixKey = numColsPrefixKey),
new StateStoreConf(sqlConf),
hadoopConf)
provider
@@ -1027,31 +1028,31 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
// Verify that trying to get incorrect versions throw errors
intercept[IllegalArgumentException] {
StateStore.get(
- storeId, keySchema, valueSchema, 0, -1, storeConf, hadoopConf)
+ storeId, keySchema, valueSchema, StatefulOperatorContext(), -1, storeConf, hadoopConf)
}
assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store
intercept[IllegalStateException] {
StateStore.get(
- storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf)
+ storeId, keySchema, valueSchema, StatefulOperatorContext(), 1, storeConf, hadoopConf)
}
// Increase version of the store and try to get again
val store0 = StateStore.get(
- storeId, keySchema, valueSchema, 0, 0, storeConf, hadoopConf)
+ storeId, keySchema, valueSchema, StatefulOperatorContext(), 0, storeConf, hadoopConf)
assert(store0.version === 0)
put(store0, "a", 0, 1)
store0.commit()
val store1 = StateStore.get(
- storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf)
+ storeId, keySchema, valueSchema, StatefulOperatorContext(), 1, storeConf, hadoopConf)
assert(StateStore.isLoaded(storeId))
assert(store1.version === 1)
assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1))
// Verify that you can also load older version
val store0reloaded = StateStore.get(
- storeId, keySchema, valueSchema, 0, 0, storeConf, hadoopConf)
+ storeId, keySchema, valueSchema, StatefulOperatorContext(), 0, storeConf, hadoopConf)
assert(store0reloaded.version === 0)
assert(rowPairsToDataSet(store0reloaded.iterator()) === Set.empty)
@@ -1060,7 +1061,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
assert(!StateStore.isLoaded(storeId))
val store1reloaded = StateStore.get(
- storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf)
+ storeId, keySchema, valueSchema, StatefulOperatorContext(), 1, storeConf, hadoopConf)
assert(StateStore.isLoaded(storeId))
assert(store1reloaded.version === 1)
put(store1reloaded, "a", 0, 2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala
index 096c3bb..dcdbac9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala
@@ -181,9 +181,11 @@ class StreamingSessionWindowStateManagerSuite extends StreamTest with BeforeAndA
stateFormatVersion)
val storeProviderId = StateStoreProviderId(stateInfo, 0, StateStoreId.DEFAULT_STORE_NAME)
+ // FIXME: event time column?
val store = StateStore.get(
storeProviderId, manager.getStateKeySchema, manager.getStateValueSchema,
- manager.getNumColsForPrefixKey, stateInfo.storeVersion, storeConf, new Configuration)
+ StatefulOperatorContext(numColsPrefixKey = manager.getNumColsForPrefixKey),
+ stateInfo.storeVersion, storeConf, new Configuration)
try {
f(manager, store)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index e89197b..d376942 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.{ContinuousMemoryStream, MemorySink}
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
+import org.apache.spark.sql.execution.streaming.state.{StatefulOperatorContext, StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.StreamSourceProvider
@@ -1418,7 +1418,7 @@ class TestStateStoreProvider extends StateStoreProvider {
stateStoreId: StateStoreId,
keySchema: StructType,
valueSchema: StructType,
- numColsPrefixKey: Int,
+ operatorContext: StatefulOperatorContext,
storeConfs: StateStoreConf,
hadoopConf: Configuration): Unit = {
throw new Exception("Successfully instantiated")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 77334ad..fbf3aae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.MemorySink
-import org.apache.spark.sql.execution.streaming.state.{StateSchemaNotCompatible, StateStore, StreamingAggregationStateManager}
+import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateSchemaNotCompatible, StateStore, StreamingAggregationStateManager}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode._
@@ -53,29 +53,47 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
import testImplicits._
+ // Runs `func` inside withSQLConf with the caller's `confPairs` plus two extra settings: the
+ // streaming aggregation state format version and the state store provider class name.
+ // A trailing "$" is stripped from `providerCls` before it is used as the provider class.
+ // NOTE(review): presumably this guards against names taken from Scala objects; the callers
+ // visible in this hunk pass classOf[...].getCanonicalName - confirm whether it is still needed.
def executeFuncWithStateVersionSQLConf(
+ providerCls: String,
stateVersion: Int,
confPairs: Seq[(String, String)],
func: => Any): Unit = {
withSQLConf(confPairs ++
- Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) {
+ Seq(
+ SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString,
+ SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerCls.stripSuffix("$")): _*) {
func
}
}
def testWithAllStateVersions(name: String, confPairs: (String, String)*)
(func: => Any): Unit = {
- for (version <- StreamingAggregationStateManager.supportedVersions) {
- test(s"$name - state format version $version") {
- executeFuncWithStateVersionSQLConf(version, confPairs, func)
+ // Register one test per (state format version, state store provider) combination. Both
+ // providers are listed so that the HDFS-backed provider keeps the coverage it had before
+ // the provider dimension was introduced.
+ val providers = Seq(
+ classOf[HDFSBackedStateStoreProvider].getCanonicalName,
+ classOf[RocksDBStateStoreProvider].getCanonicalName)
+
+ // Plain loop: registration is a side effect, there is nothing to yield.
+ for {
+ version <- StreamingAggregationStateManager.supportedVersions
+ provider <- providers
+ } {
+ test(s"$name - state format version $version / provider: $provider") {
+ executeFuncWithStateVersionSQLConf(provider, version, confPairs, func)
}
}
}
def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*)
(func: => Any): Unit = {
- for (version <- StreamingAggregationStateManager.supportedVersions) {
- testQuietly(s"$name - state format version $version") {
- executeFuncWithStateVersionSQLConf(version, confPairs, func)
+ // Quiet variant: registers one testQuietly case per (state format version, provider) pair.
+ val providerClassNames = Seq(
+ classOf[HDFSBackedStateStoreProvider].getCanonicalName,
+ classOf[RocksDBStateStoreProvider].getCanonicalName)
+
+ for {
+ version <- StreamingAggregationStateManager.supportedVersions
+ provider <- providerClassNames
+ } {
+ testQuietly(s"$name - state format version $version / provider: $provider") {
+ executeFuncWithStateVersionSQLConf(provider, version, confPairs, func)
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org