You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ka...@apache.org on 2021/07/08 12:03:53 UTC
[spark] branch branch-3.2 updated: [SPARK-35988][SS] The
implementation for RocksDBStateStoreProvider
This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new 097b667 [SPARK-35988][SS] The implementation for RocksDBStateStoreProvider
097b667 is described below
commit 097b667db77a5ba048bd8b83e1c9509dc11975fb
Author: Yuanjian Li <yu...@databricks.com>
AuthorDate: Thu Jul 8 21:02:37 2021 +0900
[SPARK-35988][SS] The implementation for RocksDBStateStoreProvider
### What changes were proposed in this pull request?
Add the implementation of RocksDBStateStoreProvider. It is a subclass of StateStoreProvider that leverages the functionality implemented in the RocksDB instance.
### Why are the changes needed?
It provides the interface through which end users can use the RocksDB state store.
### Does this PR introduce _any_ user-facing change?
Yes. Users can now use the new RocksDBStateStore in their applications.
### How was this patch tested?
New unit tests added.
Closes #33187 from xuanyuanking/SPARK-35988.
Authored-by: Yuanjian Li <yu...@databricks.com>
Signed-off-by: Jungtaek Lim <ka...@gmail.com>
(cherry picked from commit 0621e78b5f4fdb0b7235fdd59220e3afb0d36eb6)
Signed-off-by: Jungtaek Lim <ka...@gmail.com>
---
.../state/RocksDBStateStoreProvider.scala | 331 +++++++++++++++++++++
.../state/RocksDBStateStoreIntegrationSuite.scala | 51 ++++
.../streaming/state/RocksDBStateStoreSuite.scala | 152 ++++++++++
.../streaming/state/StateStoreSuite.scala | 314 +++++++++----------
4 files changed, 691 insertions(+), 157 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
new file mode 100644
index 0000000..3ebaa8c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io._
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.util.Utils
+
+private[state] class RocksDBStateStoreProvider
+ extends StateStoreProvider with Logging with Closeable {
+ import RocksDBStateStoreProvider._
+
+ class RocksDBStateStore(lastVersion: Long) extends StateStore {
+    /**
+     * Lifecycle states of this store. The trait is sealed so that the set of states is closed
+     * to this file and the compiler can check pattern matches over it for exhaustiveness.
+     */
+    sealed trait STATE
+    case object UPDATING extends STATE
+    case object COMMITTED extends STATE
+    case object ABORTED extends STATE
+
+    // Starts in UPDATING; commit() moves it to COMMITTED and abort() moves it to ABORTED.
+    @volatile private var state: STATE = UPDATING
+    // Set once a row read from the store has passed the state row format validation.
+    @volatile private var isValidated = false
+
+ override def id: StateStoreId = RocksDBStateStoreProvider.this.stateStoreId
+
+ override def version: Long = lastVersion
+
+    /**
+     * Looks up the value stored for `key`. The result is whatever the encoder decodes from the
+     * bytes returned by RocksDB, which is null when those bytes are null.
+     */
+    override def get(key: UnsafeRow): UnsafeRow = {
+      verify(key != null, "Key cannot be null")
+      val encodedKey = encoder.encode(key)
+      val decoded = encoder.decodeValue(rocksDB.get(encodedKey))
+      // The row format is validated only once per store, on the first non-null value read.
+      val shouldValidate = decoded != null && !isValidated
+      if (shouldValidate) {
+        StateStoreProvider.validateStateRowFormat(
+          key, keySchema, decoded, valueSchema, storeConf)
+        isValidated = true
+      }
+      decoded
+    }
+
+    /**
+     * Writes `value` under `key`. Only allowed while the store is still in the UPDATING state;
+     * a null key raises IllegalStateException and a null value IllegalArgumentException.
+     */
+    override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
+      verify(state == UPDATING, "Cannot put after already committed or aborted")
+      verify(key != null, "Key cannot be null")
+      require(value != null, "Cannot put a null value")
+      val keyBytes = encoder.encode(key)
+      val valueBytes = encoder.encode(value)
+      rocksDB.put(keyBytes, valueBytes)
+    }
+
+    /** Removes `key` from the store. Only allowed while the store is in the UPDATING state. */
+    override def remove(key: UnsafeRow): Unit = {
+      verify(state == UPDATING, "Cannot remove after already committed or aborted")
+      verify(key != null, "Key cannot be null")
+      val keyBytes = encoder.encode(key)
+      rocksDB.remove(keyBytes)
+    }
+
+    /**
+     * Returns an iterator over the key-value pairs of the store. The `start` and `end` bounds
+     * are not applied here: the call delegates to `iterator()`, so all pairs are returned.
+     */
+    override def getRange(
+        start: Option[UnsafeRow],
+        end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
+      val stillUpdating = state == UPDATING
+      verify(stillUpdating, "Cannot call getRange() after already committed or aborted")
+      iterator()
+    }
+
+    /**
+     * Iterates over every key-value pair in RocksDB, decoding each byte pair into rows.
+     * The decoded pair object comes from the encoder, which reuses it across calls.
+     */
+    override def iterator(): Iterator[UnsafeRowPair] = {
+      val rawPairs = rocksDB.iterator()
+      rawPairs.map { bytesPair =>
+        val decodedPair = encoder.decode(bytesPair)
+        // The row format is validated only once per store, on the first non-null value seen.
+        val shouldValidate = decodedPair.value != null && !isValidated
+        if (shouldValidate) {
+          StateStoreProvider.validateStateRowFormat(
+            decodedPair.key, keySchema, decodedPair.value, valueSchema, storeConf)
+          isValidated = true
+        }
+        decodedPair
+      }
+    }
+
+    /**
+     * Commits the pending updates through RocksDB and marks this store as COMMITTED.
+     *
+     * @return the version number returned by the RocksDB commit
+     */
+    override def commit(): Long = synchronized {
+      verify(state == UPDATING, "Cannot commit after already committed or aborted")
+      val committedVersion = rocksDB.commit()
+      state = COMMITTED
+      logInfo(s"Committed $committedVersion for $id")
+      committedVersion
+    }
+
+    /**
+     * Rolls back the pending updates and marks this store as ABORTED. Calling it again on an
+     * already aborted store is allowed; calling it after a commit is an error.
+     */
+    override def abort(): Unit = {
+      val abortAllowed = state match {
+        case UPDATING | ABORTED => true
+        case _ => false
+      }
+      verify(abortAllowed, "Cannot abort after already committed")
+      logInfo(s"Aborting ${version + 1} for $id")
+      rocksDB.rollback()
+      state = ABORTED
+    }
+
+    /**
+     * Builds the metrics reported for this store from the current RocksDB metrics: the number
+     * of uncommitted keys, the memory usage in bytes, and the RocksDB-specific custom metrics.
+     */
+    override def metrics: StateStoreMetrics = {
+      val dbMetrics = rocksDB.metrics
+
+      // Latency of the given phase of the last commit, or 0 when that phase was not reported.
+      def lastCommitLatency(phase: String): Long = {
+        dbMetrics.lastCommitLatencyMs.getOrElse(phase, 0L)
+      }
+
+      // Average latency of the given native operation truncated to a Long, or 0 when absent.
+      def avgNativeOpLatency(op: String): Long = {
+        dbMetrics.nativeOpsLatencyMicros.get(op).map(_.avg).getOrElse(0.0).toLong
+      }
+
+      // The uncompressed zip size is optional and is only reported when RocksDB provides it.
+      val customMetrics = Map[StateStoreCustomMetric, Long](
+        CUSTOM_METRIC_SST_FILE_SIZE -> dbMetrics.totalSSTFilesBytes,
+        CUSTOM_METRIC_GET_TIME -> avgNativeOpLatency("get"),
+        CUSTOM_METRIC_PUT_TIME -> avgNativeOpLatency("put"),
+        CUSTOM_METRIC_WRITEBATCH_TIME -> lastCommitLatency("writeBatch"),
+        CUSTOM_METRIC_FLUSH_TIME -> lastCommitLatency("flush"),
+        CUSTOM_METRIC_PAUSE_TIME -> lastCommitLatency("pause"),
+        CUSTOM_METRIC_CHECKPOINT_TIME -> lastCommitLatency("checkpoint"),
+        CUSTOM_METRIC_FILESYNC_TIME -> lastCommitLatency("fileSync"),
+        CUSTOM_METRIC_BYTES_COPIED -> dbMetrics.bytesCopied,
+        CUSTOM_METRIC_FILES_COPIED -> dbMetrics.filesCopied,
+        CUSTOM_METRIC_FILES_REUSED -> dbMetrics.filesReused
+      ) ++ dbMetrics.zipFileBytesUncompressed.map(bytes =>
+        Map(CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED -> bytes)).getOrElse(Map())
+
+      StateStoreMetrics(dbMetrics.numUncommittedKeys, dbMetrics.memUsageBytes, customMetrics)
+    }
+
+    /** True once `commit()` has moved this store into the COMMITTED state. */
+    override def hasCommitted: Boolean = {
+      state == COMMITTED
+    }
+
+    /** Describes the store by its operator id, partition id and checkpoint location. */
+    override def toString: String = {
+      val storeId = id
+      val idPart = s"id=(op=${storeId.operatorId},part=${storeId.partitionId})"
+      val dirPart = s"dir=${storeId.storeCheckpointLocation()}"
+      s"RocksDBStateStore[$idPart,$dirPart]"
+    }
+
+    /**
+     * Exposes the [[RocksDB]] instance backing this store, mainly for testing. The method is
+     * looked up by name through reflection in RocksDBStateStoreSuite, so the name matters.
+     */
+    def dbInstance(): RocksDB = {
+      rocksDB
+    }
+ }
+
+  /**
+   * Records the store id, schemas and configurations of this provider and then forces the
+   * creation of the lazily initialized RocksDB instance. `indexOrdinal` is accepted for the
+   * provider interface but is not used by this implementation.
+   */
+  override def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      indexOrdinal: Option[Int],
+      storeConf: StateStoreConf,
+      hadoopConf: Configuration): Unit = {
+    // All fields must be assigned before rocksDB is touched, since its initializer reads them.
+    this.hadoopConf = hadoopConf
+    this.storeConf = storeConf
+    this.valueSchema = valueSchema
+    this.keySchema = keySchema
+    this.stateStoreId_ = stateStoreId
+    rocksDB // lazy initialization
+  }
+
+  /** The id passed to `init`; null until `init` has been called. */
+  override def stateStoreId: StateStoreId = {
+    stateStoreId_
+  }
+
+  /**
+   * Loads the given version into RocksDB and returns a new store bound to that version.
+   *
+   * @param version the version to load; must not be negative
+   */
+  override def getStore(version: Long): StateStore = {
+    require(version >= 0, "Version cannot be less than 0")
+    rocksDB.load(version)
+    val store = new RocksDBStateStore(version)
+    store
+  }
+
+  /** Maintenance for this provider consists of running the RocksDB cleanup. */
+  override def doMaintenance(): Unit = rocksDB.cleanup()
+
+  /** Closes the underlying RocksDB instance. */
+  override def close(): Unit = rocksDB.close()
+
+  /** All custom metrics declared in the companion object. */
+  override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = {
+    ALL_CUSTOM_METRICS
+  }
+
+  /** The latest version as reported by the RocksDB instance. */
+  private[state] def latestVersion: Long = {
+    rocksDB.getLatestVersion()
+  }
+
+ /** Internal fields and methods */
+
+ @volatile private var stateStoreId_ : StateStoreId = _
+ @volatile private var keySchema: StructType = _
+ @volatile private var valueSchema: StructType = _
+ @volatile private var storeConf: StateStoreConf = _
+ @volatile private var hadoopConf: Configuration = _
+
+  /**
+   * The RocksDB instance backing this provider, created on first access. It uses the store's
+   * checkpoint location as its DFS root and a fresh temp directory under a Spark local
+   * directory as its local root. Reads fields set by `init`, so `init` must run first.
+   */
+  private[sql] lazy val rocksDB = {
+    val storeId = stateStoreId
+    val storeIdStr =
+      s"StateStoreId(opId=${storeId.operatorId},partId=${storeId.partitionId}," +
+        s"name=${storeId.storeName})"
+    val dfsRootDir = storeId.storeCheckpointLocation().toString
+    // Fall back to a default SparkConf when no SparkEnv is active.
+    val sparkConf = Option(SparkEnv.get) match {
+      case Some(env) => env.conf
+      case None => new SparkConf
+    }
+    val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr)
+    new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr)
+  }
+
+ private lazy val encoder = new StateEncoder
+
+  /** Throws an IllegalStateException carrying `msg` unless `condition` evaluates to true. */
+  private def verify(condition: => Boolean, msg: String): Unit = {
+    val holds = condition
+    if (!holds) {
+      throw new IllegalStateException(msg)
+    }
+  }
+
+  /**
+   * Encodes/decodes UnsafeRows to versioned byte arrays.
+   * It uses the first byte of the generated byte array to store the version that describes how
+   * the row is encoded in the rest of the byte array. Currently, the default version is 0.
+   *
+   * VERSION 0:  [ VERSION (1 byte) | ROW (N bytes) ]
+   *    The bytes of an UnsafeRow are written unmodified, starting from offset 1
+   *    (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes,
+   *    then the generated byte array will be N+1 bytes.
+   *
+   * Decoding skips the version byte(s) without inspecting them.
+   */
+  class StateEncoder {
+    import RocksDBStateStoreProvider._
+
+    // Reusable objects: the decode methods return these same instances on every call.
+    private val keyRow = new UnsafeRow(keySchema.size)
+    private val valueRow = new UnsafeRow(valueSchema.size)
+    private val rowTuple = new UnsafeRowPair()
+
+    /**
+     * Encode the UnsafeRow of N bytes as a N+1 byte array.
+     * @note This creates a new byte array and memcopies the UnsafeRow to the new array.
+     */
+    def encode(row: UnsafeRow): Array[Byte] = {
+      val bytesToEncode = row.getBytes
+      val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
+      Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
+      // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform.
+      Platform.copyMemory(
+        bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
+        encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+        bytesToEncode.length)
+      encodedBytes
+    }
+
+    /**
+     * Decode byte array for a key to a UnsafeRow. Returns null if `keyBytes` is null.
+     * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
+     *       the given byte array.
+     */
+    def decodeKey(keyBytes: Array[Byte]): UnsafeRow = pointRowTo(keyRow, keyBytes)
+
+    /**
+     * Decode byte array for a value to a UnsafeRow. Returns null if `valueBytes` is null.
+     *
+     * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
+     *       the given byte array.
+     */
+    def decodeValue(valueBytes: Array[Byte]): UnsafeRow = pointRowTo(valueRow, valueBytes)
+
+    /**
+     * Decode pair of key-value byte arrays in a pair of key-value UnsafeRows.
+     *
+     * @note The UnsafeRowPair and the UnsafeRows returned are reused across calls, and the
+     *       UnsafeRows just point to the given byte arrays.
+     */
+    def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
+      rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value))
+    }
+
+    /**
+     * Points `row` at the row bytes that follow the version byte(s) in `bytes` and returns it,
+     * or returns null when `bytes` is null. Shared by the key and value decoders so both use
+     * the same offset and length computation.
+     */
+    private def pointRowTo(row: UnsafeRow, bytes: Array[Byte]): UnsafeRow = {
+      if (bytes != null) {
+        // Platform.BYTE_ARRAY_OFFSET is the recommended way to refer to the 1st offset.
+        // See Platform.
+        row.pointTo(
+          bytes,
+          Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+          bytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+        row
+      } else {
+        null
+      }
+    }
+  }
+}
+
+/**
+ * Constants used by the RocksDB state store provider: the row encoding version written in
+ * front of every encoded row, and the custom metrics the provider reports.
+ */
+object RocksDBStateStoreProvider {
+ // Version as a single byte that specifies the encoding of the row data in RocksDB
+ val STATE_ENCODING_NUM_VERSION_BYTES = 1
+ val STATE_ENCODING_VERSION: Byte = 0
+
+ // Native operation latencies report as latency per 1000 calls
+ // as SQLMetrics support ms latency whereas RocksDB reports it in microseconds.
+ val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(
+ "rocksdbGetLatency", "RocksDB: avg get latency (per 1000 calls)")
+ val CUSTOM_METRIC_PUT_TIME = StateStoreCustomTimingMetric(
+ "rocksdbPutLatency", "RocksDB: avg put latency (per 1000 calls)")
+
+ // Commit latency detailed breakdown
+ val CUSTOM_METRIC_WRITEBATCH_TIME = StateStoreCustomTimingMetric(
+ "rocksdbCommitWriteBatchLatency", "RocksDB: commit - write batch time")
+ val CUSTOM_METRIC_FLUSH_TIME = StateStoreCustomTimingMetric(
+ "rocksdbCommitFlushLatency", "RocksDB: commit - flush time")
+ val CUSTOM_METRIC_PAUSE_TIME = StateStoreCustomTimingMetric(
+ "rocksdbCommitPauseLatency", "RocksDB: commit - pause bg time")
+ val CUSTOM_METRIC_CHECKPOINT_TIME = StateStoreCustomTimingMetric(
+ "rocksdbCommitCheckpointLatency", "RocksDB: commit - checkpoint time")
+ val CUSTOM_METRIC_FILESYNC_TIME = StateStoreCustomTimingMetric(
+ "rocksdbFileSyncTime", "RocksDB: commit - file sync time")
+ // File manager metrics (files and bytes copied or reused, uncompressed zip size)
+ val CUSTOM_METRIC_FILES_COPIED = StateStoreCustomSizeMetric(
+ "rocksdbFilesCopied", "RocksDB: file manager - files copied")
+ val CUSTOM_METRIC_BYTES_COPIED = StateStoreCustomSizeMetric(
+ "rocksdbBytesCopied", "RocksDB: file manager - bytes copied")
+ val CUSTOM_METRIC_FILES_REUSED = StateStoreCustomSizeMetric(
+ "rocksdbFilesReused", "RocksDB: file manager - files reused")
+ val CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED = StateStoreCustomSizeMetric(
+ "rocksdbZipFileBytesUncompressed", "RocksDB: file manager - uncompressed zip file bytes")
+
+ // Total SST file size
+ val CUSTOM_METRIC_SST_FILE_SIZE = StateStoreCustomSizeMetric(
+ "rocksdbSstFileSize", "RocksDB: size of all SST files")
+
+ // Every metric declared above; the provider returns this list as its supported custom metrics.
+ val ALL_CUSTOM_METRICS = Seq(
+ CUSTOM_METRIC_SST_FILE_SIZE, CUSTOM_METRIC_GET_TIME, CUSTOM_METRIC_PUT_TIME,
+ CUSTOM_METRIC_WRITEBATCH_TIME, CUSTOM_METRIC_FLUSH_TIME, CUSTOM_METRIC_PAUSE_TIME,
+ CUSTOM_METRIC_CHECKPOINT_TIME, CUSTOM_METRIC_FILESYNC_TIME,
+ CUSTOM_METRIC_BYTES_COPIED, CUSTOM_METRIC_FILES_COPIED, CUSTOM_METRIC_FILES_REUSED,
+ CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED
+ )
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
new file mode 100644
index 0000000..bf4bd3e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.File
+
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming._
+
+
+/**
+ * Runs a streaming aggregation with the RocksDB state store provider configured and checks
+ * that the state checkpoint is written in the provider's `[version].zip` layout.
+ */
+class RocksDBStateStoreIntegrationSuite extends StreamTest {
+  import testImplicits._
+
+  test("RocksDBStateStore") {
+    withTempDir { dir =>
+      val input = MemoryStream[Int]
+      val providerClassName = classOf[RocksDBStateStoreProvider].getName
+      val conf = Map(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName)
+      val checkpointPath = dir.getAbsolutePath
+      val aggregated = input.toDF.groupBy().count()
+
+      testStream(aggregated, outputMode = OutputMode.Update)(
+        StartStream(checkpointLocation = checkpointPath, additionalConfs = conf),
+        AddData(input, 1, 2, 3),
+        CheckAnswer(3),
+        AssertOnQuery { _ =>
+          // Verify that the RocksDB state store was used by checking that the state
+          // checkpoint of version 1 exists as a [version].zip file.
+          val storeCheckpointDir =
+            StateStoreId(checkpointPath + "/state", 0, 0).storeCheckpointLocation()
+          new File(storeCheckpointDir + "/1.zip").exists()
+        }
+      )
+    }
+  }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
new file mode 100644
index 0000000..b9cc844
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.UUID
+
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.LocalSparkSession.withSparkSession
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.util.Utils
+
+class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider]
+ with BeforeAndAfter {
+
+ import StateStoreTestsHelper._
+
+  // Writes one pair through the store and checks the raw RocksDB bytes start with the version.
+  test("version encoding") {
+    import RocksDBStateStoreProvider._
+
+    val storeProvider = newStoreProvider()
+    val store = storeProvider.getStore(0)
+    store.put(stringToRow("a"), intToRow(1))
+
+    val rawIterator = storeProvider.rocksDB.iterator()
+    assert(rawIterator.hasNext)
+    val rawPair = rawIterator.next()
+
+    // Verify the version encoded in first byte of the key and value byte arrays
+    val keyVersionByte = Platform.getByte(rawPair.key, Platform.BYTE_ARRAY_OFFSET)
+    val valueVersionByte = Platform.getByte(rawPair.value, Platform.BYTE_ARRAY_OFFSET)
+    assert(keyVersionByte === STATE_ENCODING_VERSION)
+    assert(valueVersionByte === STATE_ENCODING_VERSION)
+  }
+
+  // Sets RocksDB confs on the session and checks the RocksDB instance created inside a task
+  // sees the same values.
+  test("RocksDB confs are passed correctly from SparkSession to db instance") {
+    val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
+    withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
+      // Set the session confs that should be passed into RocksDB
+      val providerClassName = classOf[RocksDBStateStoreProvider].getName
+      val testConfs = Seq(
+        ("spark.sql.streaming.stateStore.providerClass", providerClassName),
+        (RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".compactOnCommit", "true"),
+        (RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".lockAcquireTimeoutMs", "10")
+      )
+      for ((confKey, confValue) <- testConfs) {
+        spark.conf.set(confKey, confValue)
+      }
+
+      // Prepare test objects for running task on state store
+      val testRDD = spark.sparkContext.makeRDD[String](Seq("a"), 1)
+      val testSchema = StructType(Seq(StructField("key", StringType, true)))
+      val checkpointDir = Utils.createTempDir().getAbsolutePath
+      val testStateInfo = StatefulOperatorStateInfo(
+        checkpointLocation = checkpointDir,
+        queryRunId = UUID.randomUUID, operatorId = 0, storeVersion = 0, numPartitions = 5)
+
+      // Create state store in a task and get the RocksDBConf from the instantiated RocksDB instance
+      val rocksDBConfInTask: RocksDBConf = testRDD.mapPartitionsWithStateStore[RocksDBConf](
+        spark.sqlContext, testStateInfo, testSchema, testSchema, None) {
+        (store: StateStore, _: Iterator[String]) =>
+          // Use reflection to get RocksDB instance
+          val dbInstanceMethod =
+            store.getClass.getMethods.find(_.getName.contains("dbInstance")).get
+          Iterator(dbInstanceMethod.invoke(store).asInstanceOf[RocksDB].conf)
+      }.collect().head
+
+      // Verify the confs are same as those configured in the session conf
+      assert(rocksDBConfInTask.compactOnCommit == true)
+      assert(rocksDBConfInTask.lockAcquireTimeoutMs == 10L)
+    }
+  }
+
+  // Commits one version and checks that the file manager metrics are reported for it.
+  test("rocksdb file manager metrics exposed") {
+    import RocksDBStateStoreProvider._
+    // Looks up a custom metric by name, failing with a descriptive message when it is absent.
+    def getCustomMetric(metrics: StateStoreMetrics, customMetric: StateStoreCustomMetric): Long = {
+      metrics.customMetrics
+        .collectFirst { case (metric, value) if metric.name == customMetric.name => value }
+        .getOrElse(fail(s"Custom metric ${customMetric.name} not found in store metrics"))
+    }
+
+    val provider = newStoreProvider()
+    try {
+      val store = provider.getStore(0)
+      // Verify state after updating
+      put(store, "a", 1)
+      assert(get(store, "a") === Some(1))
+      assert(store.commit() === 1)
+      assert(store.hasCommitted)
+      val storeMetrics = store.metrics
+      assert(storeMetrics.numKeys === 1)
+      assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_COPIED) > 0L)
+      assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_REUSED) == 0L)
+      assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_BYTES_COPIED) > 0L)
+      assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED) > 0L)
+    } finally {
+      // Close the provider so the RocksDB instance it created does not outlive the test.
+      provider.close()
+    }
+  }
+
+  /** Creates a provider on a new directory with a random operator id and partition 0. */
+  override def newStoreProvider(): RocksDBStateStoreProvider = {
+    val checkpointDir = newDir()
+    val operatorId = Random.nextInt()
+    newStoreProvider(StateStoreId(checkpointDir, operatorId, 0))
+  }
+
+  /**
+   * Creates and initializes a provider for `storeId` with a nullable string key column, a
+   * nullable integer value column, a default StateStoreConf and a default Hadoop conf.
+   */
+  def newStoreProvider(storeId: StateStoreId): RocksDBStateStoreProvider = {
+    val keyFields = Seq(StructField("key", StringType, nullable = true))
+    val valueFields = Seq(StructField("value", IntegerType, nullable = true))
+    val storeProvider = new RocksDBStateStoreProvider()
+    storeProvider.init(
+      storeId,
+      StructType(keyFields),
+      StructType(valueFields),
+      indexOrdinal = None,
+      new StateStoreConf,
+      new Configuration)
+    storeProvider
+  }
+
+  /** Reads the data of the latest version; a negative version makes getData pick the latest. */
+  override def getLatestData(storeProvider: RocksDBStateStoreProvider): Set[(String, Int)] =
+    getData(storeProvider, -1)
+
+  /**
+   * Reads all key-value pairs of a version through a freshly created provider for the same
+   * store id, so the data comes from what was checkpointed rather than from `provider` itself.
+   *
+   * @param provider the provider whose store id is reloaded
+   * @param version the version to read; a negative value means the latest version
+   * @return the pairs of that version as (key, value) tuples
+   */
+  override def getData(
+      provider: RocksDBStateStoreProvider,
+      version: Int = -1): Set[(String, Int)] = {
+    val reloadedProvider = newStoreProvider(provider.stateStoreId)
+    try {
+      val versionToRead = if (version < 0) reloadedProvider.latestVersion else version
+      // toSet materializes all rows before the provider is closed below.
+      reloadedProvider.getStore(versionToRead).iterator().map(rowsToStringInt).toSet
+    } finally {
+      // Each call creates a new provider; close it so its RocksDB instance is released.
+      reloadedProvider.close()
+    }
+  }
+
+ override protected val keySchema = StructType(Seq(StructField("key", StringType, true)))
+ override protected val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+
+  /** Both tuning parameters are ignored here; this delegates to the no-argument overload. */
+  override def newStoreProvider(
+      minDeltasForSnapshot: Int,
+      numOfVersToRetainInMemory: Int): RocksDBStateStoreProvider = {
+    newStoreProvider()
+  }
+
+  /** Both tuning parameters are ignored here; a fresh default SQLConf is returned. */
+  override def getDefaultSQLConf(
+      minDeltasForSnapshot: Int,
+      numOfVersToRetainInMemory: Int): SQLConf = {
+    new SQLConf()
+  }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 4323725..2990860 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -49,6 +49,7 @@ import org.apache.spark.util.Utils
class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
with BeforeAndAfter {
import StateStoreTestsHelper._
+ import StateStoreCoordinatorSuite._
override val keySchema = StructType(Seq(StructField("key", StringType, true)))
override val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
@@ -235,6 +236,162 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
assert(getSizeOfStateForCurrentVersion(store.metrics) > noDataMemoryUsed)
}
+ // End-to-end check of the background maintenance task: it should generate snapshots, clean
+ // up the earliest delta file, and unload store providers when the driver deactivates them,
+ // when another executor reports the same store, or when the SparkContext is stopped.
+ test("maintenance") {
+ val conf = new SparkConf()
+ .setMaster("local")
+ .setAppName("test")
+ // Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly'
+ // fails to talk to the StateStoreCoordinator and unloads all the StateStores
+ .set(RPC_NUM_RETRIES, 1)
+ val opId = 0
+ val dir1 = newDir()
+ val storeProviderId1 = StateStoreProviderId(StateStoreId(dir1, opId, 0), UUID.randomUUID)
+ val dir2 = newDir()
+ val storeProviderId2 = StateStoreProviderId(StateStoreId(dir2, opId, 1), UUID.randomUUID)
+ val sqlConf = getDefaultSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
+ SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get)
+ sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
+ // Make maintenance thread do snapshots and cleanups very fast
+ sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 10L)
+ val storeConf = StateStoreConf(sqlConf)
+ val hadoopConf = new Configuration()
+ // Separate provider on the same store id, used below only to check which files exist
+ val provider = newStoreProvider(storeProviderId1.storeId)
+
+ // Version to load for the next StateStore.get; incremented after every commit below
+ var latestStoreVersion = 0
+
+ // Commits 20 more versions of store 1, each putting key "a" with the loop index as value
+ def generateStoreVersions(): Unit = {
+ for (i <- 1 to 20) {
+ val store = StateStore.get(storeProviderId1, keySchema, valueSchema, None,
+ latestStoreVersion, storeConf, hadoopConf)
+ put(store, "a", i)
+ store.commit()
+ latestStoreVersion += 1
+ }
+ }
+
+ val timeoutDuration = 1.minute
+
+ quietly {
+ withSpark(new SparkContext(conf)) { sc =>
+ withCoordinatorRef(sc) { coordinatorRef =>
+ require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running")
+
+ // Generate sufficient versions of store for snapshots
+ generateStoreVersions()
+
+ eventually(timeout(timeoutDuration)) {
+ // Store should have been reported to the coordinator
+ assert(coordinatorRef.getLocation(storeProviderId1).nonEmpty,
+ "active instance was not reported")
+
+ // Background maintenance should clean up and generate snapshots
+ assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
+
+ // Some snapshots should have been generated
+ val snapshotVersions = (1 to latestStoreVersion).filter { version =>
+ fileExists(provider, version, isSnapshot = true)
+ }
+ assert(snapshotVersions.nonEmpty, "no snapshot file found")
+ }
+
+ // Generate more versions such that there is another snapshot and
+ // the earliest delta file will be cleaned up
+ generateStoreVersions()
+
+ // Earliest delta file should get cleaned up
+ eventually(timeout(timeoutDuration)) {
+ assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
+ }
+
+ // If driver decides to deactivate all stores related to a query run,
+ // then this instance should be unloaded
+ coordinatorRef.deactivateInstances(storeProviderId1.queryRunId)
+ eventually(timeout(timeoutDuration)) {
+ assert(!StateStore.isLoaded(storeProviderId1))
+ }
+
+ // Reload the store and verify
+ StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
+ latestStoreVersion, storeConf, hadoopConf)
+ assert(StateStore.isLoaded(storeProviderId1))
+
+ // If some other executor loads the store, then this instance should be unloaded
+ coordinatorRef
+ .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
+ eventually(timeout(timeoutDuration)) {
+ assert(!StateStore.isLoaded(storeProviderId1))
+ }
+
+ // Reload the store and verify
+ StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
+ latestStoreVersion, storeConf, hadoopConf)
+ assert(StateStore.isLoaded(storeProviderId1))
+
+ // If some other executor loads the store, and when this executor loads other store,
+ // then this executor should unload inactive instances immediately.
+ coordinatorRef
+ .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
+ StateStore.get(storeProviderId2, keySchema, valueSchema, indexOrdinal = None,
+ 0, storeConf, hadoopConf)
+ assert(!StateStore.isLoaded(storeProviderId1))
+ assert(StateStore.isLoaded(storeProviderId2))
+ }
+ }
+
+ // Verify if instance is unloaded if SparkContext is stopped
+ eventually(timeout(timeoutDuration)) {
+ require(SparkEnv.get === null)
+ assert(!StateStore.isLoaded(storeProviderId1))
+ assert(!StateStore.isLoaded(storeProviderId2))
+ assert(!StateStore.isMaintenanceRunning)
+ }
+ }
+ }
+
+ // Checks when doMaintenance() writes snapshot files for a provider created with
+ // minDeltasForSnapshot = 5, and that the data of the snapshotted and latest versions is
+ // still readable after the files earlier than the snapshot version are deleted.
+ test("snapshotting") {
+ val provider = newStoreProvider(minDeltasForSnapshot = 5, numOfVersToRetainInMemory = 2)
+
+ // Latest version written so far; reassigned from the result of each updateVersionTo call
+ var currentVersion = 0
+
+ currentVersion = updateVersionTo(provider, currentVersion, 2)
+ require(getLatestData(provider) === Set("a" -> 2))
+ provider.doMaintenance() // should not generate snapshot files
+ assert(getLatestData(provider) === Set("a" -> 2))
+
+ for (i <- 1 to currentVersion) {
+ assert(fileExists(provider, i, isSnapshot = false)) // all delta files present
+ assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present
+ }
+
+ // After version 6, snapshotting should generate one snapshot file
+ currentVersion = updateVersionTo(provider, currentVersion, 6)
+ require(getLatestData(provider) === Set("a" -> 6), "store not updated correctly")
+ provider.doMaintenance() // should generate snapshot files
+
+ // First version in 0..6 that has a snapshot file
+ val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
+ assert(snapshotVersion.nonEmpty, "snapshot file not generated")
+ deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
+ assert(
+ getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
+ "snapshotting messed up the data of the snapshotted version")
+ assert(
+ getLatestData(provider) === Set("a" -> 6),
+ "snapshotting messed up the data of the final version")
+
+ // After version 20, snapshotting should generate newer snapshot files
+ currentVersion = updateVersionTo(provider, currentVersion, 20)
+ require(getLatestData(provider) === Set("a" -> 20), "store not updated correctly")
+ provider.doMaintenance() // do snapshot
+
+ // Highest version in 0..20 that has a snapshot file
+ val latestSnapshotVersion = (0 to 20).filter(version =>
+ fileExists(provider, version, isSnapshot = true)).lastOption
+ assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
+ assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
+
+ deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
+ assert(getLatestData(provider) === Set("a" -> 20), "snapshotting messed up the data")
+ }
+
testQuietly("SPARK-18342: commit fails when rename fails") {
import RenameReturnsFalseFileSystem._
val dir = scheme + "://" + newDir()
@@ -582,7 +739,6 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
extends StateStoreCodecsTest with PrivateMethodTester {
import StateStoreTestsHelper._
- import StateStoreCoordinatorSuite._
type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
type ProviderMapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow]
@@ -761,118 +917,6 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
assert(rowsToSet(finalStore.iterator()) === Set(key -> 2))
}
- test("maintenance") {
- val conf = new SparkConf()
- .setMaster("local")
- .setAppName("test")
- // Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly'
- // fails to talk to the StateStoreCoordinator and unloads all the StateStores
- .set(RPC_NUM_RETRIES, 1)
- val opId = 0
- val dir1 = newDir()
- val storeProviderId1 = StateStoreProviderId(StateStoreId(dir1, opId, 0), UUID.randomUUID)
- val dir2 = newDir()
- val storeProviderId2 = StateStoreProviderId(StateStoreId(dir2, opId, 1), UUID.randomUUID)
- val sqlConf = getDefaultSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
- SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get)
- sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
- // Make maintenance thread do snapshots and cleanups very fast
- sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 10L)
- val storeConf = StateStoreConf(sqlConf)
- val hadoopConf = new Configuration()
- val provider = newStoreProvider(storeProviderId1.storeId)
-
- var latestStoreVersion = 0
-
- def generateStoreVersions(): Unit = {
- for (i <- 1 to 20) {
- val store = StateStore.get(storeProviderId1, keySchema, valueSchema, None,
- latestStoreVersion, storeConf, hadoopConf)
- put(store, "a", i)
- store.commit()
- latestStoreVersion += 1
- }
- }
-
- val timeoutDuration = 1.minute
-
- quietly {
- withSpark(new SparkContext(conf)) { sc =>
- withCoordinatorRef(sc) { coordinatorRef =>
- require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running")
-
- // Generate sufficient versions of store for snapshots
- generateStoreVersions()
-
- eventually(timeout(timeoutDuration)) {
- // Store should have been reported to the coordinator
- assert(coordinatorRef.getLocation(storeProviderId1).nonEmpty,
- "active instance was not reported")
-
- // Background maintenance should clean up and generate snapshots
- assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
-
- // Some snapshots should have been generated
- val snapshotVersions = (1 to latestStoreVersion).filter { version =>
- fileExists(provider, version, isSnapshot = true)
- }
- assert(snapshotVersions.nonEmpty, "no snapshot file found")
- }
-
- // Generate more versions such that there is another snapshot and
- // the earliest delta file will be cleaned up
- generateStoreVersions()
-
- // Earliest delta file should get cleaned up
- eventually(timeout(timeoutDuration)) {
- assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
- }
-
- // If driver decides to deactivate all stores related to a query run,
- // then this instance should be unloaded
- coordinatorRef.deactivateInstances(storeProviderId1.queryRunId)
- eventually(timeout(timeoutDuration)) {
- assert(!StateStore.isLoaded(storeProviderId1))
- }
-
- // Reload the store and verify
- StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
- latestStoreVersion, storeConf, hadoopConf)
- assert(StateStore.isLoaded(storeProviderId1))
-
- // If some other executor loads the store, then this instance should be unloaded
- coordinatorRef
- .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
- eventually(timeout(timeoutDuration)) {
- assert(!StateStore.isLoaded(storeProviderId1))
- }
-
- // Reload the store and verify
- StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
- latestStoreVersion, storeConf, hadoopConf)
- assert(StateStore.isLoaded(storeProviderId1))
-
- // If some other executor loads the store, and when this executor loads other store,
- // then this executor should unload inactive instances immediately.
- coordinatorRef
- .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
- StateStore.get(storeProviderId2, keySchema, valueSchema, indexOrdinal = None,
- 0, storeConf, hadoopConf)
- assert(!StateStore.isLoaded(storeProviderId1))
- assert(StateStore.isLoaded(storeProviderId2))
- }
- }
-
- // Verify if instance is unloaded if SparkContext is stopped
- eventually(timeout(timeoutDuration)) {
- require(SparkEnv.get === null)
- assert(!StateStore.isLoaded(storeProviderId1))
- assert(!StateStore.isLoaded(storeProviderId2))
- assert(!StateStore.isMaintenanceRunning)
- }
- }
- }
-
test("StateStore.get") {
quietly {
val dir = newDir()
@@ -925,50 +969,6 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
}
}
- test("snapshotting") {
- val provider = newStoreProvider(minDeltasForSnapshot = 5, numOfVersToRetainInMemory = 2)
-
- var currentVersion = 0
-
- currentVersion = updateVersionTo(provider, currentVersion, 2)
- require(getLatestData(provider) === Set("a" -> 2))
- provider.doMaintenance() // should not generate snapshot files
- assert(getLatestData(provider) === Set("a" -> 2))
-
- for (i <- 1 to currentVersion) {
- assert(fileExists(provider, i, isSnapshot = false)) // all delta files present
- assert(!fileExists(provider, i, isSnapshot = true)) // no snapshot files present
- }
-
- // After version 6, snapshotting should generate one snapshot file
- currentVersion = updateVersionTo(provider, currentVersion, 6)
- require(getLatestData(provider) === Set("a" -> 6), "store not updated correctly")
- provider.doMaintenance() // should generate snapshot files
-
- val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
- assert(snapshotVersion.nonEmpty, "snapshot file not generated")
- deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
- assert(
- getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
- "snapshotting messed up the data of the snapshotted version")
- assert(
- getLatestData(provider) === Set("a" -> 6),
- "snapshotting messed up the data of the final version")
-
- // After version 20, snapshotting should generate newer snapshot files
- currentVersion = updateVersionTo(provider, currentVersion, 20)
- require(getLatestData(provider) === Set("a" -> 20), "store not updated correctly")
- provider.doMaintenance() // do snapshot
-
- val latestSnapshotVersion = (0 to 20).filter(version =>
- fileExists(provider, version, isSnapshot = true)).lastOption
- assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
- assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
-
- deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
- assert(getLatestData(provider) === Set("a" -> 20), "snapshotting messed up the data")
- }
-
test("reports memory usage") {
val provider = newStoreProvider()
val store = provider.getStore(0)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org