Posted to commits@spark.apache.org by ka...@apache.org on 2021/07/08 12:03:53 UTC

[spark] branch branch-3.2 updated: [SPARK-35988][SS] The implementation for RocksDBStateStoreProvider

This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 097b667  [SPARK-35988][SS] The implementation for RocksDBStateStoreProvider
097b667 is described below

commit 097b667db77a5ba048bd8b83e1c9509dc11975fb
Author: Yuanjian Li <yu...@databricks.com>
AuthorDate: Thu Jul 8 21:02:37 2021 +0900

    [SPARK-35988][SS] The implementation for RocksDBStateStoreProvider
    
    ### What changes were proposed in this pull request?
    Add the implementation for the RocksDBStateStoreProvider. It's a subclass of StateStoreProvider that leverages all the functionality implemented in the RocksDB instance.
    
    ### Why are the changes needed?
    It provides the interface for end users to use the RocksDB state store.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. The new RocksDBStateStore can be used in user applications.
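    
    For example, a minimal sketch of enabling the new provider through the session conf (the config key and class name are the ones exercised by the new tests below):
    
        spark.conf.set(
          "spark.sql.streaming.stateStore.providerClass",
          "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")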
    
    ### How was this patch tested?
    New unit tests added.
    
    Closes #33187 from xuanyuanking/SPARK-35988.
    
    Authored-by: Yuanjian Li <yu...@databricks.com>
    Signed-off-by: Jungtaek Lim <ka...@gmail.com>
    (cherry picked from commit 0621e78b5f4fdb0b7235fdd59220e3afb0d36eb6)
    Signed-off-by: Jungtaek Lim <ka...@gmail.com>
---
 .../state/RocksDBStateStoreProvider.scala          | 331 +++++++++++++++++++++
 .../state/RocksDBStateStoreIntegrationSuite.scala  |  51 ++++
 .../streaming/state/RocksDBStateStoreSuite.scala   | 152 ++++++++++
 .../streaming/state/StateStoreSuite.scala          | 314 +++++++++----------
 4 files changed, 691 insertions(+), 157 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
new file mode 100644
index 0000000..3ebaa8c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io._
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.util.Utils
+
+private[state] class RocksDBStateStoreProvider
+  extends StateStoreProvider with Logging with Closeable {
+  import RocksDBStateStoreProvider._
+
+  class RocksDBStateStore(lastVersion: Long) extends StateStore {
+    /** Trait and classes representing the internal state of the store */
+    trait STATE
+    case object UPDATING extends STATE
+    case object COMMITTED extends STATE
+    case object ABORTED extends STATE
+
+    @volatile private var state: STATE = UPDATING
+    @volatile private var isValidated = false
+
+    override def id: StateStoreId = RocksDBStateStoreProvider.this.stateStoreId
+
+    override def version: Long = lastVersion
+
+    override def get(key: UnsafeRow): UnsafeRow = {
+      verify(key != null, "Key cannot be null")
+      val value = encoder.decodeValue(rocksDB.get(encoder.encode(key)))
+      if (!isValidated && value != null) {
+        StateStoreProvider.validateStateRowFormat(
+          key, keySchema, value, valueSchema, storeConf)
+        isValidated = true
+      }
+      value
+    }
+
+    override def put(key: UnsafeRow, value: UnsafeRow): Unit = {
+      verify(state == UPDATING, "Cannot put after already committed or aborted")
+      verify(key != null, "Key cannot be null")
+      require(value != null, "Cannot put a null value")
+      rocksDB.put(encoder.encode(key), encoder.encode(value))
+    }
+
+    override def remove(key: UnsafeRow): Unit = {
+      verify(state == UPDATING, "Cannot remove after already committed or aborted")
+      verify(key != null, "Key cannot be null")
+      rocksDB.remove(encoder.encode(key))
+    }
+
+    override def getRange(
+        start: Option[UnsafeRow],
+        end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
+      verify(state == UPDATING, "Cannot call getRange() after already committed or aborted")
+      iterator()
+    }
+
+    override def iterator(): Iterator[UnsafeRowPair] = {
+      rocksDB.iterator().map { kv =>
+        val rowPair = encoder.decode(kv)
+        if (!isValidated && rowPair.value != null) {
+          StateStoreProvider.validateStateRowFormat(
+            rowPair.key, keySchema, rowPair.value, valueSchema, storeConf)
+          isValidated = true
+        }
+        rowPair
+      }
+    }
+
+    override def commit(): Long = synchronized {
+      verify(state == UPDATING, "Cannot commit after already committed or aborted")
+      val newVersion = rocksDB.commit()
+      state = COMMITTED
+      logInfo(s"Committed $newVersion for $id")
+      newVersion
+    }
+
+    override def abort(): Unit = {
+      verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed")
+      logInfo(s"Aborting ${version + 1} for $id")
+      rocksDB.rollback()
+      state = ABORTED
+    }
+
+    override def metrics: StateStoreMetrics = {
+      val rocksDBMetrics = rocksDB.metrics
+      def commitLatencyMs(typ: String): Long = rocksDBMetrics.lastCommitLatencyMs.getOrElse(typ, 0L)
+      def avgNativeOpsLatencyMs(typ: String): Long = {
+        rocksDBMetrics.nativeOpsLatencyMicros.get(typ).map(_.avg).getOrElse(0.0).toLong
+      }
+
+      val stateStoreCustomMetrics = Map[StateStoreCustomMetric, Long](
+        CUSTOM_METRIC_SST_FILE_SIZE -> rocksDBMetrics.totalSSTFilesBytes,
+        CUSTOM_METRIC_GET_TIME -> avgNativeOpsLatencyMs("get"),
+        CUSTOM_METRIC_PUT_TIME -> avgNativeOpsLatencyMs("put"),
+        CUSTOM_METRIC_WRITEBATCH_TIME -> commitLatencyMs("writeBatch"),
+        CUSTOM_METRIC_FLUSH_TIME -> commitLatencyMs("flush"),
+        CUSTOM_METRIC_PAUSE_TIME -> commitLatencyMs("pause"),
+        CUSTOM_METRIC_CHECKPOINT_TIME -> commitLatencyMs("checkpoint"),
+        CUSTOM_METRIC_FILESYNC_TIME -> commitLatencyMs("fileSync"),
+        CUSTOM_METRIC_BYTES_COPIED -> rocksDBMetrics.bytesCopied,
+        CUSTOM_METRIC_FILES_COPIED -> rocksDBMetrics.filesCopied,
+        CUSTOM_METRIC_FILES_REUSED -> rocksDBMetrics.filesReused
+      ) ++ rocksDBMetrics.zipFileBytesUncompressed.map(bytes =>
+        Map(CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED -> bytes)).getOrElse(Map())
+
+      StateStoreMetrics(
+        rocksDBMetrics.numUncommittedKeys,
+        rocksDBMetrics.memUsageBytes,
+        stateStoreCustomMetrics)
+    }
+
+    override def hasCommitted: Boolean = state == COMMITTED
+
+    override def toString: String = {
+      s"RocksDBStateStore[id=(op=${id.operatorId},part=${id.partitionId})," +
+        s"dir=${id.storeCheckpointLocation()}]"
+    }
+
+    /** Return the [[RocksDB]] instance in this store. This is exposed mainly for testing. */
+    def dbInstance(): RocksDB = rocksDB
+  }
+
+  override def init(
+      stateStoreId: StateStoreId,
+      keySchema: StructType,
+      valueSchema: StructType,
+      indexOrdinal: Option[Int],
+      storeConf: StateStoreConf,
+      hadoopConf: Configuration): Unit = {
+    this.stateStoreId_ = stateStoreId
+    this.keySchema = keySchema
+    this.valueSchema = valueSchema
+    this.storeConf = storeConf
+    this.hadoopConf = hadoopConf
+    rocksDB // lazy initialization
+  }
+
+  override def stateStoreId: StateStoreId = stateStoreId_
+
+  override def getStore(version: Long): StateStore = {
+    require(version >= 0, "Version cannot be less than 0")
+    rocksDB.load(version)
+    new RocksDBStateStore(version)
+  }
+
+  override def doMaintenance(): Unit = {
+    rocksDB.cleanup()
+  }
+
+  override def close(): Unit = {
+    rocksDB.close()
+  }
+
+  override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = ALL_CUSTOM_METRICS
+
+  private[state] def latestVersion: Long = rocksDB.getLatestVersion()
+
+  /** Internal fields and methods */
+
+  @volatile private var stateStoreId_ : StateStoreId = _
+  @volatile private var keySchema: StructType = _
+  @volatile private var valueSchema: StructType = _
+  @volatile private var storeConf: StateStoreConf = _
+  @volatile private var hadoopConf: Configuration = _
+
+  private[sql] lazy val rocksDB = {
+    val dfsRootDir = stateStoreId.storeCheckpointLocation().toString
+    val storeIdStr = s"StateStoreId(opId=${stateStoreId.operatorId}," +
+      s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})"
+    val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
+    val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr)
+    new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr)
+  }
+
+  private lazy val encoder = new StateEncoder
+
+  private def verify(condition: => Boolean, msg: String): Unit = {
+    if (!condition) { throw new IllegalStateException(msg) }
+  }
+
+  /**
+   * Encodes/decodes UnsafeRows to versioned byte arrays.
+   * It uses the first byte of the generated byte array to store the version that describes how the
+ * row is encoded in the rest of the byte array. Currently, the default version is 0.
+   *
+   * VERSION 0:  [ VERSION (1 byte) | ROW (N bytes) ]
+ *    The bytes of an UnsafeRow are written unmodified, starting from offset 1
+ *    (offset 0 is the version byte of value 0). That is, if the UnsafeRow has N bytes,
+ *    then the generated byte array will be N+1 bytes.
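+ *    For example, an 8-byte UnsafeRow is encoded as a 9-byte array: [ 0x00 | 8 row bytes ].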
+   */
+  class StateEncoder {
+    import RocksDBStateStoreProvider._
+
+    // Reusable objects
+    private val keyRow = new UnsafeRow(keySchema.size)
+    private val valueRow = new UnsafeRow(valueSchema.size)
+    private val rowTuple = new UnsafeRowPair()
+
+    /**
+     * Encode the UnsafeRow of N bytes as an N+1 byte array.
+     * @note This creates a new byte array and memcopies the UnsafeRow to the new array.
+     */
+    def encode(row: UnsafeRow): Array[Byte] = {
+      val bytesToEncode = row.getBytes
+      val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
+      Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
+      // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy between byte arrays. See Platform.
+      Platform.copyMemory(
+        bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
+        encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+        bytesToEncode.length)
+      encodedBytes
+    }
+
+    /**
+     * Decode a byte array for a key into an UnsafeRow.
+     * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
+     *       the given byte array.
+     */
+    def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+      if (keyBytes != null) {
+        // Platform.BYTE_ARRAY_OFFSET is the recommended way to refer to the first offset. See Platform.
+        keyRow.pointTo(
+          keyBytes,
+          Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+          keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+        keyRow
+      } else {
+        null
+      }
+    }
+
+    /**
+     * Decode a byte array for a value into an UnsafeRow.
+     *
+     * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
+     *       the given byte array.
+     */
+    def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
+      if (valueBytes != null) {
+        // Platform.BYTE_ARRAY_OFFSET is the recommended way to refer to the first offset. See Platform.
+        valueRow.pointTo(
+          valueBytes,
+          Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+          valueBytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+        valueRow
+      } else {
+        null
+      }
+    }
+
+    /**
+     * Decode a pair of key-value byte arrays into a pair of key-value UnsafeRows.
+     *
+     * @note The UnsafeRows returned are reused across calls, and they just point to
+     *       the given byte arrays.
+     */
+    def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
+      rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value))
+    }
+  }
+}
+
+object RocksDBStateStoreProvider {
+  // Version as a single byte that specifies the encoding of the row data in RocksDB
+  val STATE_ENCODING_NUM_VERSION_BYTES = 1
+  val STATE_ENCODING_VERSION: Byte = 0
+
+  // Native operation latencies are reported as latency per 1000 calls,
+  // because SQLMetrics supports ms latency whereas RocksDB reports it in microseconds.
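+  // For example, an avg get latency of 50 us per call is reported as 50 (ms per 1000 calls).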
+  val CUSTOM_METRIC_GET_TIME = StateStoreCustomTimingMetric(
+    "rocksdbGetLatency", "RocksDB: avg get latency (per 1000 calls)")
+  val CUSTOM_METRIC_PUT_TIME = StateStoreCustomTimingMetric(
+    "rocksdbPutLatency", "RocksDB: avg put latency (per 1000 calls)")
+
+  // Commit latency detailed breakdown
+  val CUSTOM_METRIC_WRITEBATCH_TIME = StateStoreCustomTimingMetric(
+    "rocksdbCommitWriteBatchLatency", "RocksDB: commit - write batch time")
+  val CUSTOM_METRIC_FLUSH_TIME = StateStoreCustomTimingMetric(
+    "rocksdbCommitFlushLatency", "RocksDB: commit - flush time")
+  val CUSTOM_METRIC_PAUSE_TIME = StateStoreCustomTimingMetric(
+    "rocksdbCommitPauseLatency", "RocksDB: commit - pause bg time")
+  val CUSTOM_METRIC_CHECKPOINT_TIME = StateStoreCustomTimingMetric(
+    "rocksdbCommitCheckpointLatency", "RocksDB: commit - checkpoint time")
+  val CUSTOM_METRIC_FILESYNC_TIME = StateStoreCustomTimingMetric(
+    "rocksdbFileSyncTime", "RocksDB: commit - file sync time")
+  val CUSTOM_METRIC_FILES_COPIED = StateStoreCustomSizeMetric(
+    "rocksdbFilesCopied", "RocksDB: file manager - files copied")
+  val CUSTOM_METRIC_BYTES_COPIED = StateStoreCustomSizeMetric(
+    "rocksdbBytesCopied", "RocksDB: file manager - bytes copied")
+  val CUSTOM_METRIC_FILES_REUSED = StateStoreCustomSizeMetric(
+    "rocksdbFilesReused", "RocksDB: file manager - files reused")
+  val CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED = StateStoreCustomSizeMetric(
+    "rocksdbZipFileBytesUncompressed", "RocksDB: file manager - uncompressed zip file bytes")
+
+  // Total SST file size
+  val CUSTOM_METRIC_SST_FILE_SIZE = StateStoreCustomSizeMetric(
+    "rocksdbSstFileSize", "RocksDB: size of all SST files")
+
+  val ALL_CUSTOM_METRICS = Seq(
+    CUSTOM_METRIC_SST_FILE_SIZE, CUSTOM_METRIC_GET_TIME, CUSTOM_METRIC_PUT_TIME,
+    CUSTOM_METRIC_WRITEBATCH_TIME, CUSTOM_METRIC_FLUSH_TIME, CUSTOM_METRIC_PAUSE_TIME,
+    CUSTOM_METRIC_CHECKPOINT_TIME, CUSTOM_METRIC_FILESYNC_TIME,
+    CUSTOM_METRIC_BYTES_COPIED, CUSTOM_METRIC_FILES_COPIED, CUSTOM_METRIC_FILES_REUSED,
+    CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED
+  )
+}
+
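For reference, here is a minimal sketch of the provider lifecycle that the new suites below exercise (init, load a version, update, commit). The checkpoint directory is a placeholder, stringToRow/intToRow are helpers from the test code, and since the provider is private[state], real applications enable it via the conf key shown in the commit message; this is illustration only:

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.sql.types._

    val keySchema = StructType(Seq(StructField("key", StringType, true)))
    val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))

    val provider = new RocksDBStateStoreProvider()
    provider.init(
      StateStoreId("/tmp/state-checkpoint", 0, 0),  // checkpoint root, operator id, partition id
      keySchema, valueSchema, indexOrdinal = None,
      new StateStoreConf, new Configuration)

    val store = provider.getStore(0)          // loads version 0 into the local RocksDB instance
    store.put(stringToRow("a"), intToRow(1))  // rows are version-encoded before hitting RocksDB
    val newVersion = store.commit()           // checkpoints and uploads to DFS; returns version 1
    provider.close()
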
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
new file mode 100644
index 0000000..bf4bd3e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.io.File
+
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming._
+
+
+class RocksDBStateStoreIntegrationSuite extends StreamTest {
+  import testImplicits._
+
+  test("RocksDBStateStore") {
+    withTempDir { dir =>
+      val input = MemoryStream[Int]
+      val conf = Map(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName)
+
+      testStream(input.toDF.groupBy().count(), outputMode = OutputMode.Update)(
+        StartStream(checkpointLocation = dir.getAbsolutePath, additionalConfs = conf),
+        AddData(input, 1, 2, 3),
+        CheckAnswer(3),
+        AssertOnQuery { q =>
+          // Verify RocksDBStateStore is used by checking that state checkpoints are [version].zip
+          val storeCheckpointDir = StateStoreId(
+            dir.getAbsolutePath + "/state", 0, 0).storeCheckpointLocation()
+          val storeCheckpointFile = storeCheckpointDir + "/1.zip"
+          new File(storeCheckpointFile).exists()
+        }
+      )
+    }
+  }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
new file mode 100644
index 0000000..b9cc844
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import java.util.UUID
+
+import scala.util.Random
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.LocalSparkSession.withSparkSession
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.util.Utils
+
+class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider]
+  with BeforeAndAfter {
+
+  import StateStoreTestsHelper._
+
+  test("version encoding") {
+    import RocksDBStateStoreProvider._
+
+    val provider = newStoreProvider()
+    val store = provider.getStore(0)
+    val keyRow = stringToRow("a")
+    val valueRow = intToRow(1)
+    store.put(keyRow, valueRow)
+    val iter = provider.rocksDB.iterator()
+    assert(iter.hasNext)
+    val kv = iter.next()
+
+    // Verify the version encoded in the first byte of the key and value byte arrays
+    assert(Platform.getByte(kv.key, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION)
+    assert(Platform.getByte(kv.value, Platform.BYTE_ARRAY_OFFSET) === STATE_ENCODING_VERSION)
+  }
+
+  test("RocksDB confs are passed correctly from SparkSession to db instance") {
+    val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
+    withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
+      // Set the session confs that should be passed into RocksDB
+      val testConfs = Seq(
+        ("spark.sql.streaming.stateStore.providerClass",
+          classOf[RocksDBStateStoreProvider].getName),
+        (RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".compactOnCommit", "true"),
+        (RocksDBConf.ROCKSDB_CONF_NAME_PREFIX + ".lockAcquireTimeoutMs", "10")
+      )
+      testConfs.foreach { case (k, v) => spark.conf.set(k, v) }
+
+      // Prepare test objects for running task on state store
+      val testRDD = spark.sparkContext.makeRDD[String](Seq("a"), 1)
+      val testSchema = StructType(Seq(StructField("key", StringType, true)))
+      val testStateInfo = StatefulOperatorStateInfo(
+        checkpointLocation = Utils.createTempDir().getAbsolutePath,
+        queryRunId = UUID.randomUUID, operatorId = 0, storeVersion = 0, numPartitions = 5)
+
+      // Create state store in a task and get the RocksDBConf from the instantiated RocksDB instance
+      val rocksDBConfInTask: RocksDBConf = testRDD.mapPartitionsWithStateStore[RocksDBConf](
+        spark.sqlContext, testStateInfo, testSchema, testSchema, None) {
+          (store: StateStore, _: Iterator[String]) =>
+            // Use reflection to get RocksDB instance
+            val dbInstanceMethod =
+              store.getClass.getMethods.filter(_.getName.contains("dbInstance")).head
+            Iterator(dbInstanceMethod.invoke(store).asInstanceOf[RocksDB].conf)
+        }.collect().head
+
+      // Verify the confs are the same as those configured in the session conf
+      assert(rocksDBConfInTask.compactOnCommit)
+      assert(rocksDBConfInTask.lockAcquireTimeoutMs == 10L)
+    }
+  }
+
+  test("rocksdb file manager metrics exposed") {
+    import RocksDBStateStoreProvider._
+    def getCustomMetric(metrics: StateStoreMetrics, customMetric: StateStoreCustomMetric): Long = {
+      val metricPair = metrics.customMetrics.find(_._1.name == customMetric.name)
+      assert(metricPair.isDefined)
+      metricPair.get._2
+    }
+
+    val provider = newStoreProvider()
+    val store = provider.getStore(0)
+    // Verify state after updating
+    put(store, "a", 1)
+    assert(get(store, "a") === Some(1))
+    assert(store.commit() === 1)
+    assert(store.hasCommitted)
+    val storeMetrics = store.metrics
+    assert(storeMetrics.numKeys === 1)
+    assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_COPIED) > 0L)
+    assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_FILES_REUSED) == 0L)
+    assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_BYTES_COPIED) > 0L)
+    assert(getCustomMetric(storeMetrics, CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED) > 0L)
+  }
+
+  override def newStoreProvider(): RocksDBStateStoreProvider = {
+    newStoreProvider(StateStoreId(newDir(), Random.nextInt(), 0))
+  }
+
+  def newStoreProvider(storeId: StateStoreId): RocksDBStateStoreProvider = {
+    val keySchema = StructType(Seq(StructField("key", StringType, true)))
+    val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+    val provider = new RocksDBStateStoreProvider()
+    provider.init(
+      storeId, keySchema, valueSchema, indexOrdinal = None, new StateStoreConf, new Configuration)
+    provider
+  }
+
+  override def getLatestData(storeProvider: RocksDBStateStoreProvider): Set[(String, Int)] = {
+    getData(storeProvider, version = -1)
+  }
+
+  override def getData(
+      provider: RocksDBStateStoreProvider,
+      version: Int = -1): Set[(String, Int)] = {
+    val reloadedProvider = newStoreProvider(provider.stateStoreId)
+    val versionToRead = if (version < 0) reloadedProvider.latestVersion else version
+    reloadedProvider.getStore(versionToRead).iterator().map(rowsToStringInt).toSet
+  }
+
+  override protected val keySchema = StructType(Seq(StructField("key", StringType, true)))
+  override protected val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+
+  override def newStoreProvider(
+    minDeltasForSnapshot: Int,
+    numOfVersToRetainInMemory: Int): RocksDBStateStoreProvider = newStoreProvider()
+
+  override def getDefaultSQLConf(
+    minDeltasForSnapshot: Int,
+    numOfVersToRetainInMemory: Int): SQLConf = new SQLConf()
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 4323725..2990860 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -49,6 +49,7 @@ import org.apache.spark.util.Utils
 class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   with BeforeAndAfter {
   import StateStoreTestsHelper._
+  import StateStoreCoordinatorSuite._
 
   override val keySchema = StructType(Seq(StructField("key", StringType, true)))
   override val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
@@ -235,6 +236,162 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     assert(getSizeOfStateForCurrentVersion(store.metrics) > noDataMemoryUsed)
   }
 
+  test("maintenance") {
+    val conf = new SparkConf()
+      .setMaster("local")
+      .setAppName("test")
+      // Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly'
+      // fails to talk to the StateStoreCoordinator and unloads all the StateStores
+      .set(RPC_NUM_RETRIES, 1)
+    val opId = 0
+    val dir1 = newDir()
+    val storeProviderId1 = StateStoreProviderId(StateStoreId(dir1, opId, 0), UUID.randomUUID)
+    val dir2 = newDir()
+    val storeProviderId2 = StateStoreProviderId(StateStoreId(dir2, opId, 1), UUID.randomUUID)
+    val sqlConf = getDefaultSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
+      SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get)
+    sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
+    // Make maintenance thread do snapshots and cleanups very fast
+    sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 10L)
+    val storeConf = StateStoreConf(sqlConf)
+    val hadoopConf = new Configuration()
+    val provider = newStoreProvider(storeProviderId1.storeId)
+
+    var latestStoreVersion = 0
+
+    def generateStoreVersions(): Unit = {
+      for (i <- 1 to 20) {
+        val store = StateStore.get(storeProviderId1, keySchema, valueSchema, None,
+          latestStoreVersion, storeConf, hadoopConf)
+        put(store, "a", i)
+        store.commit()
+        latestStoreVersion += 1
+      }
+    }
+
+    val timeoutDuration = 1.minute
+
+    quietly {
+      withSpark(new SparkContext(conf)) { sc =>
+        withCoordinatorRef(sc) { coordinatorRef =>
+          require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running")
+
+          // Generate sufficient versions of store for snapshots
+          generateStoreVersions()
+
+          eventually(timeout(timeoutDuration)) {
+            // Store should have been reported to the coordinator
+            assert(coordinatorRef.getLocation(storeProviderId1).nonEmpty,
+              "active instance was not reported")
+
+            // Background maintenance should clean up and generate snapshots
+            assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
+
+            // Some snapshots should have been generated
+            val snapshotVersions = (1 to latestStoreVersion).filter { version =>
+              fileExists(provider, version, isSnapshot = true)
+            }
+            assert(snapshotVersions.nonEmpty, "no snapshot file found")
+          }
+
+          // Generate more versions such that there is another snapshot and
+          // the earliest delta file will be cleaned up
+          generateStoreVersions()
+
+          // Earliest delta file should get cleaned up
+          eventually(timeout(timeoutDuration)) {
+            assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
+          }
+
+          // If driver decides to deactivate all stores related to a query run,
+          // then this instance should be unloaded
+          coordinatorRef.deactivateInstances(storeProviderId1.queryRunId)
+          eventually(timeout(timeoutDuration)) {
+            assert(!StateStore.isLoaded(storeProviderId1))
+          }
+
+          // Reload the store and verify
+          StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
+            latestStoreVersion, storeConf, hadoopConf)
+          assert(StateStore.isLoaded(storeProviderId1))
+
+          // If some other executor loads the store, then this instance should be unloaded
+          coordinatorRef
+            .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
+          eventually(timeout(timeoutDuration)) {
+            assert(!StateStore.isLoaded(storeProviderId1))
+          }
+
+          // Reload the store and verify
+          StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
+            latestStoreVersion, storeConf, hadoopConf)
+          assert(StateStore.isLoaded(storeProviderId1))
+
+          // If some other executor loads the store, and when this executor loads other store,
+          // then this executor should unload inactive instances immediately.
+          coordinatorRef
+            .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
+          StateStore.get(storeProviderId2, keySchema, valueSchema, indexOrdinal = None,
+            0, storeConf, hadoopConf)
+          assert(!StateStore.isLoaded(storeProviderId1))
+          assert(StateStore.isLoaded(storeProviderId2))
+        }
+      }
+
+      // Verify that instances are unloaded if SparkContext is stopped
+      eventually(timeout(timeoutDuration)) {
+        require(SparkEnv.get === null)
+        assert(!StateStore.isLoaded(storeProviderId1))
+        assert(!StateStore.isLoaded(storeProviderId2))
+        assert(!StateStore.isMaintenanceRunning)
+      }
+    }
+  }
+
+  test("snapshotting") {
+    val provider = newStoreProvider(minDeltasForSnapshot = 5, numOfVersToRetainInMemory = 2)
+
+    var currentVersion = 0
+
+    currentVersion = updateVersionTo(provider, currentVersion, 2)
+    require(getLatestData(provider) === Set("a" -> 2))
+    provider.doMaintenance()               // should not generate snapshot files
+    assert(getLatestData(provider) === Set("a" -> 2))
+
+    for (i <- 1 to currentVersion) {
+      assert(fileExists(provider, i, isSnapshot = false))  // all delta files present
+      assert(!fileExists(provider, i, isSnapshot = true))  // no snapshot files present
+    }
+
+    // After version 6, snapshotting should generate one snapshot file
+    currentVersion = updateVersionTo(provider, currentVersion, 6)
+    require(getLatestData(provider) === Set("a" -> 6), "store not updated correctly")
+    provider.doMaintenance()       // should generate snapshot files
+
+    val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
+    assert(snapshotVersion.nonEmpty, "snapshot file not generated")
+    deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
+    assert(
+      getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
+      "snapshotting messed up the data of the snapshotted version")
+    assert(
+      getLatestData(provider) === Set("a" -> 6),
+      "snapshotting messed up the data of the final version")
+
+    // After version 20, snapshotting should generate newer snapshot files
+    currentVersion = updateVersionTo(provider, currentVersion, 20)
+    require(getLatestData(provider) === Set("a" -> 20), "store not updated correctly")
+    provider.doMaintenance()       // do snapshot
+
+    val latestSnapshotVersion = (0 to 20).filter(version =>
+      fileExists(provider, version, isSnapshot = true)).lastOption
+    assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
+    assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
+
+    deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
+    assert(getLatestData(provider) === Set("a" -> 20), "snapshotting messed up the data")
+  }
+
   testQuietly("SPARK-18342: commit fails when rename fails") {
     import RenameReturnsFalseFileSystem._
     val dir = scheme + "://" + newDir()
@@ -582,7 +739,6 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
   extends StateStoreCodecsTest with PrivateMethodTester {
   import StateStoreTestsHelper._
-  import StateStoreCoordinatorSuite._
 
   type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
   type ProviderMapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow]
@@ -761,118 +917,6 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     assert(rowsToSet(finalStore.iterator()) === Set(key -> 2))
   }
 
-  test("maintenance") {
-    val conf = new SparkConf()
-      .setMaster("local")
-      .setAppName("test")
-      // Make sure that when SparkContext stops, the StateStore maintenance thread 'quickly'
-      // fails to talk to the StateStoreCoordinator and unloads all the StateStores
-      .set(RPC_NUM_RETRIES, 1)
-    val opId = 0
-    val dir1 = newDir()
-    val storeProviderId1 = StateStoreProviderId(StateStoreId(dir1, opId, 0), UUID.randomUUID)
-    val dir2 = newDir()
-    val storeProviderId2 = StateStoreProviderId(StateStoreId(dir2, opId, 1), UUID.randomUUID)
-    val sqlConf = getDefaultSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
-      SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get)
-    sqlConf.setConf(SQLConf.MIN_BATCHES_TO_RETAIN, 2)
-    // Make maintenance thread do snapshots and cleanups very fast
-    sqlConf.setConf(SQLConf.STREAMING_MAINTENANCE_INTERVAL, 10L)
-    val storeConf = StateStoreConf(sqlConf)
-    val hadoopConf = new Configuration()
-    val provider = newStoreProvider(storeProviderId1.storeId)
-
-    var latestStoreVersion = 0
-
-    def generateStoreVersions(): Unit = {
-      for (i <- 1 to 20) {
-        val store = StateStore.get(storeProviderId1, keySchema, valueSchema, None,
-          latestStoreVersion, storeConf, hadoopConf)
-        put(store, "a", i)
-        store.commit()
-        latestStoreVersion += 1
-      }
-    }
-
-    val timeoutDuration = 1.minute
-
-    quietly {
-      withSpark(new SparkContext(conf)) { sc =>
-        withCoordinatorRef(sc) { coordinatorRef =>
-          require(!StateStore.isMaintenanceRunning, "StateStore is unexpectedly running")
-
-          // Generate sufficient versions of store for snapshots
-          generateStoreVersions()
-
-          eventually(timeout(timeoutDuration)) {
-            // Store should have been reported to the coordinator
-            assert(coordinatorRef.getLocation(storeProviderId1).nonEmpty,
-              "active instance was not reported")
-
-            // Background maintenance should clean up and generate snapshots
-            assert(StateStore.isMaintenanceRunning, "Maintenance task is not running")
-
-            // Some snapshots should have been generated
-            val snapshotVersions = (1 to latestStoreVersion).filter { version =>
-              fileExists(provider, version, isSnapshot = true)
-            }
-            assert(snapshotVersions.nonEmpty, "no snapshot file found")
-          }
-
-          // Generate more versions such that there is another snapshot and
-          // the earliest delta file will be cleaned up
-          generateStoreVersions()
-
-          // Earliest delta file should get cleaned up
-          eventually(timeout(timeoutDuration)) {
-            assert(!fileExists(provider, 1, isSnapshot = false), "earliest file not deleted")
-          }
-
-          // If driver decides to deactivate all stores related to a query run,
-          // then this instance should be unloaded
-          coordinatorRef.deactivateInstances(storeProviderId1.queryRunId)
-          eventually(timeout(timeoutDuration)) {
-            assert(!StateStore.isLoaded(storeProviderId1))
-          }
-
-          // Reload the store and verify
-          StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
-            latestStoreVersion, storeConf, hadoopConf)
-          assert(StateStore.isLoaded(storeProviderId1))
-
-          // If some other executor loads the store, then this instance should be unloaded
-          coordinatorRef
-            .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
-          eventually(timeout(timeoutDuration)) {
-            assert(!StateStore.isLoaded(storeProviderId1))
-          }
-
-          // Reload the store and verify
-          StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
-            latestStoreVersion, storeConf, hadoopConf)
-          assert(StateStore.isLoaded(storeProviderId1))
-
-          // If some other executor loads the store, and when this executor loads other store,
-          // then this executor should unload inactive instances immediately.
-          coordinatorRef
-            .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
-          StateStore.get(storeProviderId2, keySchema, valueSchema, indexOrdinal = None,
-            0, storeConf, hadoopConf)
-          assert(!StateStore.isLoaded(storeProviderId1))
-          assert(StateStore.isLoaded(storeProviderId2))
-        }
-      }
-
-      // Verify if instance is unloaded if SparkContext is stopped
-      eventually(timeout(timeoutDuration)) {
-        require(SparkEnv.get === null)
-        assert(!StateStore.isLoaded(storeProviderId1))
-        assert(!StateStore.isLoaded(storeProviderId2))
-        assert(!StateStore.isMaintenanceRunning)
-      }
-    }
-  }
-
   test("StateStore.get") {
     quietly {
       val dir = newDir()
@@ -925,50 +969,6 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     }
   }
 
-  test("snapshotting") {
-    val provider = newStoreProvider(minDeltasForSnapshot = 5, numOfVersToRetainInMemory = 2)
-
-    var currentVersion = 0
-
-    currentVersion = updateVersionTo(provider, currentVersion, 2)
-    require(getLatestData(provider) === Set("a" -> 2))
-    provider.doMaintenance()               // should not generate snapshot files
-    assert(getLatestData(provider) === Set("a" -> 2))
-
-    for (i <- 1 to currentVersion) {
-      assert(fileExists(provider, i, isSnapshot = false))  // all delta files present
-      assert(!fileExists(provider, i, isSnapshot = true))  // no snapshot files present
-    }
-
-    // After version 6, snapshotting should generate one snapshot file
-    currentVersion = updateVersionTo(provider, currentVersion, 6)
-    require(getLatestData(provider) === Set("a" -> 6), "store not updated correctly")
-    provider.doMaintenance()       // should generate snapshot files
-
-    val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
-    assert(snapshotVersion.nonEmpty, "snapshot file not generated")
-    deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
-    assert(
-      getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
-      "snapshotting messed up the data of the snapshotted version")
-    assert(
-      getLatestData(provider) === Set("a" -> 6),
-      "snapshotting messed up the data of the final version")
-
-    // After version 20, snapshotting should generate newer snapshot files
-    currentVersion = updateVersionTo(provider, currentVersion, 20)
-    require(getLatestData(provider) === Set("a" -> 20), "store not updated correctly")
-    provider.doMaintenance()       // do snapshot
-
-    val latestSnapshotVersion = (0 to 20).filter(version =>
-      fileExists(provider, version, isSnapshot = true)).lastOption
-    assert(latestSnapshotVersion.nonEmpty, "no snapshot file found")
-    assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
-
-    deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
-    assert(getLatestData(provider) === Set("a" -> 20), "snapshotting messed up the data")
-  }
-
   test("reports memory usage") {
     val provider = newStoreProvider()
     val store = provider.getStore(0)
