Posted to commits@spark.apache.org by vi...@apache.org on 2021/07/12 16:17:54 UTC

[spark] branch branch-3.2 updated: [SPARK-35861][SS] Introduce "prefix match scan" feature on state store

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 07011eb  [SPARK-35861][SS] Introduce "prefix match scan" feature on state store
07011eb is described below

commit 07011eb77973a87433e208ab67f3068c54a66b4c
Author: Jungtaek Lim <ka...@gmail.com>
AuthorDate: Mon Jul 12 09:06:50 2021 -0700

    [SPARK-35861][SS] Introduce "prefix match scan" feature on state store
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to introduce a new feature, "prefix match scan", on the state store, which enables users of the state store (mostly stateful operators) to group keys into logical groups and scan the keys in the same group efficiently.
    
    For example, if the key schema of the state store is `[ sessionId | session.start ]`, we can scan with a prefix key whose schema is `[ sessionId ]` (the leftmost 1 column) and retrieve all key-value pairs in the state store whose keys match the given prefix key.
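
    For illustration, a minimal sketch (not part of this patch; the field names here are hypothetical) of such a key schema and its prefix key schema:

    ```
    import org.apache.spark.sql.types._

    // Hypothetical full key schema of the state store: [ sessionId | sessionStart ]
    val keySchema: StructType = new StructType()
      .add("sessionId", StringType)
      .add("sessionStart", LongType)

    // Prefix key schema: the leftmost 1 column of the full key schema
    val prefixKeySchema: StructType = StructType(keySchema.take(1))  // [ sessionId ]
    ```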
    
    This PR brings API changes, though the changes are confined to the developer API.
    
    * Registering the prefix key
    
    We propose to make an explicit change to the init() method of StateStoreProvider, as below:
    
    ```
    def init(
          stateStoreId: StateStoreId,
          keySchema: StructType,
          valueSchema: StructType,
          numColsPrefixKey: Int,
          storeConfs: StateStoreConf,
          hadoopConf: Configuration): Unit
    ```
    
    Please note that we also remove the unused parameter "keyIndexOrdinal". That parameter is coupled with getRange(), which we remove as well; see below for the rationale.

    Here we provide the number of columns used to project the prefix key from the full key. If the operator doesn't leverage prefix match scan, the value can (and should) be 0, since supporting prefix scan may require the state store provider to use an underlying storage format that brings extra overhead; passing 0 lets the provider optimize it away.
    
    We would like to apply some restrictions on the prefix key to simplify the functionality:

    * The prefix key is a part of the full key. It can't be the same as the full key.
      * That is, the full key is (prefix key + remaining parts), and both the prefix key and the remaining parts must have at least one column.
    * We always take the columns from the leftmost, sequentially, like "seq.take(nums)" (see the sketch after this list).
    * We don't allow reordering of the columns.
    * We only guarantee "equality" comparison against prefix keys, and don't support a prefix "range" scan.
      * We only support scanning the keys which match the given prefix key.
      * E.g. we don't support a range scan from user A to user B due to technical complexity. That's the reason we can't leverage the existing getRange API.
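
    For illustration, here is the sketch referenced above; it mirrors how the built-in providers in this patch project the prefix key from the leftmost columns of the full key (no reordering):

    ```
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
    import org.apache.spark.sql.types.StructType

    // Build an UnsafeProjection extracting the leftmost `numColsPrefixKey` columns of the
    // full key, analogous to seq.take(n). Mirrors the approach used by the built-in providers.
    def prefixKeyProjection(keySchema: StructType, numColsPrefixKey: Int): UnsafeProjection = {
      require(numColsPrefixKey > 0 && keySchema.length > numColsPrefixKey)
      val refs = keySchema.zipWithIndex.take(numColsPrefixKey).map {
        case (field, ordinal) => BoundReference(ordinal, field.dataType, field.nullable)
      }
      UnsafeProjection.create(refs)
    }
    ```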
    
    As mentioned, we want to make an explicit change to the init() method of StateStoreProvider, which breaks backward compatibility, given that 3rd party state store providers need to update their code anyway to support prefix match scan. Since the RocksDB state store provider is being donated to the OSS and is planned to be available in Spark 3.2, the majority of users are expected to migrate to the built-in state store providers, which should remedy the concern.
    
    * Scanning key-value pairs matched to the prefix key
    
    We propose to add a new method to the ReadStateStore (and StateStore by inheritance), as below:
    
    ```
    def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair]
    ```
    
    We require callers to pass a `prefixKey` which has the same schema as the registered prefix key. In other words, the schema of the parameter `prefixKey` should match the projection of the prefix key from the full key, based on the number of columns for the prefix key.

    The method contract is clear: the method returns an iterator over the key-value pairs whose prefix key matches the given prefix key. Callers should rely only on this contract and should not expect any other characteristics based on the specific details of a state store provider.

    From the caller's point of view, the prefix key is only used for retrieving key-value pairs via prefix match scan. Callers should keep using the full key for CRUD operations.
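
    For illustration, a hypothetical caller-side sketch (the helper name is not part of this patch), assuming the store was initialized with a prefix key and `prefixProjection` was built over the same leftmost columns:

    ```
    import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
    import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, UnsafeRowPair}

    // Hypothetical helper: retrieve all state rows belonging to one group key.
    def rowsForGroup(
        store: ReadStateStore,
        fullKey: UnsafeRow,
        prefixProjection: UnsafeProjection): Iterator[UnsafeRowPair] = {
      val prefixKey = prefixProjection(fullKey)  // schema must equal the registered prefix key schema
      store.prefixScan(prefixKey)                // equality match on the prefix; keep using the full key for CRUD
    }
    ```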
    
    Note that this PR also proposes a breaking change, the removal of getRange(), which has never been implemented properly and hence has never been called properly.
    
    ### Why are the changes needed?
    
    * Introducing prefix match scan feature
    
    Currently, the state store API is based only on a key-value data structure. It lacks more advanced data structures such as list-like ones, which has required us to implement such structures on our own whenever we need them. We had one in stream-stream join, and we were about to add another one for native session window. A custom data structure built on top of the state store API tends to be complicated and has to deal with multiple state stores.

    We decided to enhance the state store API a bit to remove the need for native session window to implement its own. The native session window operator will just need to do a prefix scan on the group key to retrieve all sessions belonging to that group key.

    Making the feature part of the state store API also enables state store providers to optimize the implementation based on their characteristics. (E.g., we will implement this in the RocksDB state store provider by leveraging the fact that RocksDB sorts keys by the natural order of their binary format.)
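
    For illustration only, a conceptual sketch of prefix scan over a sorted key space, which is the property the RocksDB provider relies on (the actual implementation in this patch seeks a RocksDB iterator instead of scanning a collection):

    ```
    // Conceptual only: with keys kept in sorted (lexicographic) order, prefix scan is
    // "seek to the prefix, then read forward while the key still starts with the prefix".
    def prefixScanSorted(sortedKeys: Seq[Array[Byte]], prefix: Array[Byte]): Iterator[Array[Byte]] = {
      def hasPrefix(key: Array[Byte]): Boolean = key.take(prefix.length).sameElements(prefix)
      sortedKeys.iterator
        .dropWhile(k => !hasPrefix(k))   // "seek" to the first matching key
        .takeWhile(hasPrefix)            // matching keys are contiguous in sorted order
    }
    ```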
    
    * Removal of getRange API
    
    Before introducing this feature we sought a way to leverage getRange, but it is quite hard to implement efficiently while respecting its method contract. Spark always calls the method with (None, None) parameters, and all the state store providers (including the built-in ones) implement it as just calling iterator(), which does not respect the method contract. Given that, we can replace all getRange() usages with iterator() and remove the API to avoid any confusion/concerns. (See the migration sketch below.)
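
    For illustration, the call-site migration is mechanical and mirrors the changes in this patch (e.g. in StreamingAggregationStateManager and FlatMapGroupsWithStateExecHelper):

    ```
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.execution.streaming.state.ReadStateStore

    // Every existing getRange(None, None) call was a full scan in disguise,
    // so the equivalent after this patch is a plain iterator().
    def allKeys(store: ReadStateStore): Iterator[UnsafeRow] =
      store.iterator().map(_.key)   // previously: store.getRange(None, None).map(_.key)
    ```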
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, for end users and maintainers of 3rd party state store providers, who will need to update their state store provider implementations to adopt this change.
    
    ### How was this patch tested?
    
    Added a UT, and ran existing UTs to make sure this doesn't break anything.
    
    Closes #33038 from HeartSaVioR/SPARK-35861.
    
    Authored-by: Jungtaek Lim <ka...@gmail.com>
    Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
    (cherry picked from commit 094300fa609e3028e29346641baae7174ca9a1c8)
    Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
---
 .../streaming/FlatMapGroupsWithStateExec.scala     |   5 +-
 .../state/FlatMapGroupsWithStateExecHelper.scala   |   2 +-
 .../streaming/state/HDFSBackedStateStoreMap.scala  | 170 +++++++++++
 .../state/HDFSBackedStateStoreProvider.scala       |  94 +++---
 .../sql/execution/streaming/state/RocksDB.scala    |  35 +++
 .../streaming/state/RocksDBStateEncoder.scala      | 252 +++++++++++++++
 .../state/RocksDBStateStoreProvider.scala          | 117 ++-----
 .../sql/execution/streaming/state/StateStore.scala |  54 ++--
 .../execution/streaming/state/StateStoreRDD.scala  |   8 +-
 .../state/StreamingAggregationStateManager.scala   |   2 +-
 .../state/SymmetricHashJoinStateManager.scala      |   6 +-
 .../sql/execution/streaming/state/package.scala    |  12 +-
 .../execution/streaming/statefulOperators.scala    |   8 +-
 .../sql/execution/streaming/streamingLimits.scala  |   2 +-
 .../streaming/state/MemoryStateStore.scala         |   4 +
 .../streaming/state/RocksDBStateStoreSuite.scala   |  35 ++-
 .../streaming/state/StateStoreRDDSuite.scala       | 128 ++++----
 .../streaming/state/StateStoreSuite.scala          | 338 ++++++++++++---------
 .../apache/spark/sql/streaming/StreamSuite.scala   |   2 +-
 19 files changed, 861 insertions(+), 413 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
index 0e0fbe0..03694d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala
@@ -221,7 +221,6 @@ case class FlatMapGroupsWithStateExec(
         // The state store aware zip partitions will provide us with two iterators,
         // child data iterator and the initial state iterator per partition.
         case (partitionId, childDataIterator, initStateIterator) =>
-
           val stateStoreId = StateStoreId(
             stateInfo.get.checkpointLocation, stateInfo.get.operatorId, partitionId)
           val storeProviderId = StateStoreProviderId(stateStoreId, stateInfo.get.queryRunId)
@@ -229,7 +228,7 @@ case class FlatMapGroupsWithStateExec(
             storeProviderId,
             groupingAttributes.toStructType,
             stateManager.stateSchema,
-            indexOrdinal = None,
+            numColsPrefixKey = 0,
             stateInfo.get.storeVersion, storeConf, hadoopConfBroadcast.value.value)
           val processor = new InputProcessor(store)
           processDataWithPartition(childDataIterator, store, processor, Some(initStateIterator))
@@ -239,7 +238,7 @@ case class FlatMapGroupsWithStateExec(
         getStateInfo,
         groupingAttributes.toStructType,
         stateManager.stateSchema,
-        indexOrdinal = None,
+        numColsPrefixKey = 0,
         session.sqlContext.sessionState,
         Some(session.sqlContext.streams.stateStoreCoordinator)
       ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
index 2d9824e..d396e71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala
@@ -103,7 +103,7 @@ object FlatMapGroupsWithStateExecHelper {
 
     override def getAllState(store: StateStore): Iterator[StateData] = {
       val stateData = StateData()
-      store.getRange(None, None).map { p =>
+      store.iterator().map { p =>
         stateData.withNew(p.key, p.value, getStateObject(p.value), getTimestamp(p.value))
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala
new file mode 100644
index 0000000..73608d4
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.types.{StructField, StructType}
+
+trait HDFSBackedStateStoreMap {
+  def size(): Int
+  def get(key: UnsafeRow): UnsafeRow
+  def put(key: UnsafeRow, value: UnsafeRow): UnsafeRow
+  def putAll(map: HDFSBackedStateStoreMap): Unit
+  def remove(key: UnsafeRow): UnsafeRow
+  def iterator(): Iterator[UnsafeRowPair]
+  def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair]
+  def clear(): Unit
+}
+
+object HDFSBackedStateStoreMap {
+  // ConcurrentHashMap is used because it generates fail-safe iterators on filtering
+  // - The iterator is weakly consistent with the map, i.e., iterator's data reflect the values in
+  //   the map when the iterator was created
+  // - Any updates to the map while iterating through the filtered iterator does not throw
+  //   java.util.ConcurrentModificationException
+  type MapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow]
+
+  def create(keySchema: StructType, numColsPrefixKey: Int): HDFSBackedStateStoreMap = {
+    if (numColsPrefixKey > 0) {
+      new PrefixScannableHDFSBackedStateStoreMap(keySchema, numColsPrefixKey)
+    } else {
+      new NoPrefixHDFSBackedStateStoreMap()
+    }
+  }
+}
+
+class NoPrefixHDFSBackedStateStoreMap extends HDFSBackedStateStoreMap {
+  private val map = new HDFSBackedStateStoreMap.MapType()
+
+  override def size: Int = map.size()
+
+  override def get(key: UnsafeRow): UnsafeRow = map.get(key)
+
+  override def put(key: UnsafeRow, value: UnsafeRow): UnsafeRow = map.put(key, value)
+
+  def putAll(other: HDFSBackedStateStoreMap): Unit = {
+    other match {
+      case o: NoPrefixHDFSBackedStateStoreMap => map.putAll(o.map)
+      case _ => other.iterator().foreach { pair => put(pair.key, pair.value) }
+    }
+  }
+
+  override def remove(key: UnsafeRow): UnsafeRow = map.remove(key)
+
+  override def iterator(): Iterator[UnsafeRowPair] = {
+    val unsafeRowPair = new UnsafeRowPair()
+    map.entrySet.asScala.iterator.map { entry =>
+      unsafeRowPair.withRows(entry.getKey, entry.getValue)
+    }
+  }
+
+  override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
+    throw new UnsupportedOperationException("Prefix scan is not supported!")
+  }
+
+  override def clear(): Unit = map.clear()
+}
+
+class PrefixScannableHDFSBackedStateStoreMap(
+    keySchema: StructType,
+    numColsPrefixKey: Int) extends HDFSBackedStateStoreMap {
+
+  private val map = new HDFSBackedStateStoreMap.MapType()
+
+  // We are using ConcurrentHashMap here with the same rationalization we use ConcurrentHashMap on
+  // HDFSBackedStateStoreMap.MapType.
+  private val prefixKeyToKeysMap = new java.util.concurrent.ConcurrentHashMap[
+    UnsafeRow, mutable.Set[UnsafeRow]]()
+
+  private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.take(numColsPrefixKey)
+  }
+
+  private val prefixKeyProjection: UnsafeProjection = {
+    val refs = prefixKeyFieldsWithIdx.map(x => BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  override def size: Int = map.size()
+
+  override def get(key: UnsafeRow): UnsafeRow = map.get(key)
+
+  override def put(key: UnsafeRow, value: UnsafeRow): UnsafeRow = {
+    val ret = map.put(key, value)
+
+    val prefixKey = prefixKeyProjection(key).copy()
+    prefixKeyToKeysMap.compute(prefixKey, (_, v) => {
+      if (v == null) {
+        val set = new mutable.HashSet[UnsafeRow]()
+        set.add(key)
+        set
+      } else {
+        v.add(key)
+        v
+      }
+    })
+
+    ret
+  }
+
+  def putAll(other: HDFSBackedStateStoreMap): Unit = {
+    other match {
+      case o: PrefixScannableHDFSBackedStateStoreMap =>
+        map.putAll(o.map)
+        prefixKeyToKeysMap.putAll(o.prefixKeyToKeysMap)
+
+      case _ => other.iterator().foreach { pair => put(pair.key, pair.value) }
+    }
+  }
+
+  override def remove(key: UnsafeRow): UnsafeRow = {
+    val ret = map.remove(key)
+
+    if (ret != null) {
+      val prefixKey = prefixKeyProjection(key).copy()
+      prefixKeyToKeysMap.computeIfPresent(prefixKey, (_, v) => {
+        v.remove(key)
+        if (v.isEmpty) null else v
+      })
+    }
+
+    ret
+  }
+
+  override def iterator(): Iterator[UnsafeRowPair] = {
+    val unsafeRowPair = new UnsafeRowPair()
+    map.entrySet.asScala.iterator.map { entry =>
+      unsafeRowPair.withRows(entry.getKey, entry.getValue)
+    }
+  }
+
+  override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
+    val unsafeRowPair = new UnsafeRowPair()
+    prefixKeyToKeysMap.getOrDefault(prefixKey, mutable.Set.empty[UnsafeRow])
+      .iterator
+      .map { key => unsafeRowPair.withRows(key, map.get(key)) }
+  }
+
+  override def clear(): Unit = {
+    map.clear()
+    prefixKeyToKeysMap.clear()
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 6fdd39da..c604021 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -72,14 +72,7 @@ import org.apache.spark.util.{SizeEstimator, Utils}
  */
 private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging {
 
-  // ConcurrentHashMap is used because it generates fail-safe iterators on filtering
-  // - The iterator is weakly consistent with the map, i.e., iterator's data reflect the values in
-  //   the map when the iterator was created
-  // - Any updates to the map while iterating through the filtered iterator does not throw
-  //   java.util.ConcurrentModificationException
-  type MapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow]
-
-  class HDFSBackedReadStateStore(val version: Long, map: MapType)
+  class HDFSBackedReadStateStore(val version: Long, map: HDFSBackedStateStoreMap)
     extends ReadStateStore {
 
     override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId
@@ -87,10 +80,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     override def get(key: UnsafeRow): UnsafeRow = map.get(key)
 
     override def iterator(): Iterator[UnsafeRowPair] = {
-      val unsafeRowPair = new UnsafeRowPair()
-      map.entrySet.asScala.iterator.map { entry =>
-        unsafeRowPair.withRows(entry.getKey, entry.getValue)
-      }
+      map.iterator()
     }
 
     override def abort(): Unit = {}
@@ -98,10 +88,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     override def toString(): String = {
       s"HDFSReadStateStore[id=(op=${id.operatorId},part=${id.partitionId}),dir=$baseDir]"
     }
+
+    override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
+      map.prefixScan(prefixKey)
+    }
   }
 
   /** Implementation of [[StateStore]] API which is backed by an HDFS-compatible file system */
-  class HDFSBackedStateStore(val version: Long, mapToUpdate: MapType)
+  class HDFSBackedStateStore(val version: Long, mapToUpdate: HDFSBackedStateStoreMap)
     extends StateStore {
 
     /** Trait and classes representing the internal state of the store */
@@ -139,13 +133,6 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
       }
     }
 
-    override def getRange(
-        start: Option[UnsafeRow],
-        end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
-      verify(state == UPDATING, "Cannot getRange after already committed or aborted")
-      iterator()
-    }
-
     /** Commit all the updates that have been made to the store, and return the new version. */
     override def commit(): Long = {
       verify(state == UPDATING, "Cannot commit after already committed or aborted")
@@ -179,11 +166,10 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
      * Get an iterator of all the store data.
      * This can be called only after committing all the updates made in the current thread.
      */
-    override def iterator(): Iterator[UnsafeRowPair] = {
-      val unsafeRowPair = new UnsafeRowPair()
-      mapToUpdate.entrySet.asScala.iterator.map { entry =>
-        unsafeRowPair.withRows(entry.getKey, entry.getValue)
-      }
+    override def iterator(): Iterator[UnsafeRowPair] = mapToUpdate.iterator()
+
+    override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
+      mapToUpdate.prefixScan(prefixKey)
     }
 
     override def metrics: StateStoreMetrics = {
@@ -231,9 +217,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     new HDFSBackedReadStateStore(version, newMap)
   }
 
-  private def getLoadedMapForStore(version: Long): MapType = synchronized {
+  private def getLoadedMapForStore(version: Long): HDFSBackedStateStoreMap = synchronized {
     require(version >= 0, "Version cannot be less than 0")
-    val newMap = new MapType()
+    val newMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
     if (version > 0) {
       newMap.putAll(loadMap(version))
     }
@@ -244,7 +230,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
-      indexOrdinal: Option[Int], // for sorting the data
+      numColsPrefixKey: Int,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): Unit = {
     this.stateStoreId_ = stateStoreId
@@ -253,6 +239,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     this.storeConf = storeConf
     this.hadoopConf = hadoopConf
     this.numberOfVersionsToRetainInMemory = storeConf.maxVersionsToRetainInMemory
+
+    require((keySchema.length == 0 && numColsPrefixKey == 0) ||
+      (keySchema.length > numColsPrefixKey), "The number of columns in the key must be " +
+      "greater than the number of columns for prefix key!")
+    this.numColsPrefixKey = numColsPrefixKey
+
     fm.mkdirs(baseDir)
   }
 
@@ -291,11 +283,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
   @volatile private var storeConf: StateStoreConf = _
   @volatile private var hadoopConf: Configuration = _
   @volatile private var numberOfVersionsToRetainInMemory: Int = _
+  @volatile private var numColsPrefixKey: Int = 0
+
   // TODO: The validation should be moved to a higher level so that it works for all state store
   // implementations
   @volatile private var isValidated = false
 
-  private lazy val loadedMaps = new util.TreeMap[Long, MapType](Ordering[Long].reverse)
+  private lazy val loadedMaps = new util.TreeMap[Long, HDFSBackedStateStoreMap](
+    Ordering[Long].reverse)
   private lazy val baseDir = stateStoreId.storeCheckpointLocation()
   private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf)
   private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
@@ -317,7 +312,10 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
 
   private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean)
 
-  private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = {
+  private def commitUpdates(
+      newVersion: Long,
+      map: HDFSBackedStateStoreMap,
+      output: DataOutputStream): Unit = {
     synchronized {
       finalizeDeltaFile(output)
       putStateIntoStateCacheMap(newVersion, map)
@@ -332,21 +330,20 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     val versionsInFiles = fetchFiles().map(_.version).toSet
     val versionsLoaded = loadedMaps.keySet.asScala
     val allKnownVersions = versionsInFiles ++ versionsLoaded
-    val unsafeRowTuple = new UnsafeRowPair()
     if (allKnownVersions.nonEmpty) {
-      loadMap(allKnownVersions.max).entrySet().iterator().asScala.map { entry =>
-        unsafeRowTuple.withRows(entry.getKey, entry.getValue)
-      }
+      loadMap(allKnownVersions.max).iterator()
     } else Iterator.empty
   }
 
   /** This method is intended to be only used for unit test(s). DO NOT TOUCH ELEMENTS IN MAP! */
-  private[state] def getLoadedMaps(): util.SortedMap[Long, MapType] = synchronized {
+  private[state] def getLoadedMaps(): util.SortedMap[Long, HDFSBackedStateStoreMap] = synchronized {
     // shallow copy as a minimal guard
-    loadedMaps.clone().asInstanceOf[util.SortedMap[Long, MapType]]
+    loadedMaps.clone().asInstanceOf[util.SortedMap[Long, HDFSBackedStateStoreMap]]
   }
 
-  private def putStateIntoStateCacheMap(newVersion: Long, map: MapType): Unit = synchronized {
+  private def putStateIntoStateCacheMap(
+      newVersion: Long,
+      map: HDFSBackedStateStoreMap): Unit = synchronized {
     if (numberOfVersionsToRetainInMemory <= 0) {
       if (loadedMaps.size() > 0) loadedMaps.clear()
       return
@@ -373,7 +370,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
   }
 
   /** Load the required version of the map data from the backing files */
-  private def loadMap(version: Long): MapType = {
+  private def loadMap(version: Long): HDFSBackedStateStoreMap = {
 
     // Shortcut if the map for this version is already there to avoid a redundant put.
     val loadedCurrentVersionMap = synchronized { Option(loadedMaps.get(version)) }
@@ -398,13 +395,13 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
       // Find the most recent map before this version that we can.
       // [SPARK-22305] This must be done iteratively to avoid stack overflow.
       var lastAvailableVersion = version
-      var lastAvailableMap: Option[MapType] = None
+      var lastAvailableMap: Option[HDFSBackedStateStoreMap] = None
       while (lastAvailableMap.isEmpty) {
         lastAvailableVersion -= 1
 
         if (lastAvailableVersion <= 0) {
           // Use an empty map for versions 0 or less.
-          lastAvailableMap = Some(new MapType)
+          lastAvailableMap = Some(HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey))
         } else {
           lastAvailableMap =
             synchronized { Option(loadedMaps.get(lastAvailableVersion)) }
@@ -414,7 +411,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
 
       // Load all the deltas from the version after the last available one up to the target version.
       // The last available version is the one with a full snapshot, so it doesn't need deltas.
-      val resultMap = new MapType(lastAvailableMap.get)
+      val resultMap = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
+      resultMap.putAll(lastAvailableMap.get)
       for (deltaVersion <- lastAvailableVersion + 1 to version) {
         updateFromDeltaFile(deltaVersion, resultMap)
       }
@@ -452,7 +450,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     output.close()
   }
 
-  private def updateFromDeltaFile(version: Long, map: MapType): Unit = {
+  private def updateFromDeltaFile(version: Long, map: HDFSBackedStateStoreMap): Unit = {
     val fileToRead = deltaFile(version)
     var input: DataInputStream = null
     val sourceStream = try {
@@ -506,18 +504,18 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     logInfo(s"Read delta file for version $version of $this from $fileToRead")
   }
 
-  private def writeSnapshotFile(version: Long, map: MapType): Unit = {
+  private def writeSnapshotFile(version: Long, map: HDFSBackedStateStoreMap): Unit = {
     val targetFile = snapshotFile(version)
     var rawOutput: CancellableFSDataOutputStream = null
     var output: DataOutputStream = null
     try {
       rawOutput = fm.createAtomic(targetFile, overwriteIfPossible = true)
       output = compressStream(rawOutput)
-      val iter = map.entrySet().iterator()
+      val iter = map.iterator()
       while(iter.hasNext) {
         val entry = iter.next()
-        val keyBytes = entry.getKey.getBytes()
-        val valueBytes = entry.getValue.getBytes()
+        val keyBytes = entry.key.getBytes()
+        val valueBytes = entry.value.getBytes()
         output.writeInt(keyBytes.size)
         output.write(keyBytes)
         output.writeInt(valueBytes.size)
@@ -554,9 +552,9 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit
     }
   }
 
-  private def readSnapshotFile(version: Long): Option[MapType] = {
+  private def readSnapshotFile(version: Long): Option[HDFSBackedStateStoreMap] = {
     val fileToRead = snapshotFile(version)
-    val map = new MapType()
+    val map = HDFSBackedStateStoreMap.create(keySchema, numColsPrefixKey)
     var input: DataInputStream = null
 
     try {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index 9952d5d..9b8569e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -22,6 +22,7 @@ import java.util.Locale
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.{mutable, Map}
+import scala.collection.JavaConverters._
 import scala.ref.WeakReference
 import scala.util.Try
 
@@ -92,6 +93,9 @@ class RocksDB(
   @GuardedBy("acquireLock")
   @volatile private var acquiredThreadInfo: AcquiredThreadInfo = _
 
+  private val prefixScanReuseIter =
+    new java.util.concurrent.ConcurrentHashMap[Long, RocksIterator]()
+
   /**
    * Load the given version of data in a native RocksDB instance.
    * Note that this will copy all the necessary file from DFS to local disk as needed,
@@ -185,6 +189,33 @@ class RocksDB(
     }
   }
 
+  def prefixScan(prefix: Array[Byte]): Iterator[ByteArrayPair] = {
+    val threadId = Thread.currentThread().getId
+    val iter = prefixScanReuseIter.computeIfAbsent(threadId, tid => {
+      val it = writeBatch.newIteratorWithBase(db.newIterator())
+      logInfo(s"Getting iterator from version $loadedVersion for prefix scan on " +
+        s"thread ID $tid")
+      it
+    })
+
+    iter.seek(prefix)
+
+    new NextIterator[ByteArrayPair] {
+      override protected def getNext(): ByteArrayPair = {
+        if (iter.isValid && iter.key().take(prefix.length).sameElements(prefix)) {
+          byteArrayPair.set(iter.key, iter.value)
+          iter.next()
+          byteArrayPair
+        } else {
+          finished = true
+          null
+        }
+      }
+
+      override protected def close(): Unit = {}
+    }
+  }
+
   /**
    * Commit all the updates made as a version to DFS. The steps it needs to do to commits are:
    * - Write all the updates to the native RocksDB
@@ -254,6 +285,8 @@ class RocksDB(
    * Drop uncommitted changes, and roll back to previous version.
    */
   def rollback(): Unit = {
+    prefixScanReuseIter.entrySet().asScala.foreach(_.getValue.close())
+    prefixScanReuseIter.clear()
     writeBatch.clear()
     numKeysOnWritingVersion = numKeysOnLoadedVersion
     release()
@@ -269,6 +302,8 @@ class RocksDB(
 
   /** Release all resources */
   def close(): Unit = {
+    prefixScanReuseIter.entrySet().asScala.foreach(_.getValue.close())
+    prefixScanReuseIter.clear()
     try {
       closeDB()
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
new file mode 100644
index 0000000..81755e5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.state
+
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION}
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.unsafe.Platform
+
+sealed trait RocksDBStateEncoder {
+  def supportPrefixKeyScan: Boolean
+  def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
+  def extractPrefixKey(key: UnsafeRow): UnsafeRow
+
+  def encodeKey(row: UnsafeRow): Array[Byte]
+  def encodeValue(row: UnsafeRow): Array[Byte]
+
+  def decodeKey(keyBytes: Array[Byte]): UnsafeRow
+  def decodeValue(valueBytes: Array[Byte]): UnsafeRow
+  def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair
+}
+
+object RocksDBStateEncoder {
+  def getEncoder(
+      keySchema: StructType,
+      valueSchema: StructType,
+      numColsPrefixKey: Int): RocksDBStateEncoder = {
+    if (numColsPrefixKey > 0) {
+      new PrefixKeyScanStateEncoder(keySchema, valueSchema, numColsPrefixKey)
+    } else {
+      new NoPrefixKeyStateEncoder(keySchema, valueSchema)
+    }
+  }
+
+  /**
+   * Encode the UnsafeRow of N bytes as a N+1 byte array.
+   * @note This creates a new byte array and memcopies the UnsafeRow to the new array.
+   */
+  def encodeUnsafeRow(row: UnsafeRow): Array[Byte] = {
+    val bytesToEncode = row.getBytes
+    val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
+    Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
+    // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform.
+    Platform.copyMemory(
+      bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+      bytesToEncode.length)
+    encodedBytes
+  }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], numFields: Int): UnsafeRow = {
+    if (bytes != null) {
+      val row = new UnsafeRow(numFields)
+      decodeToUnsafeRow(bytes, row)
+    } else {
+      null
+    }
+  }
+
+  def decodeToUnsafeRow(bytes: Array[Byte], reusedRow: UnsafeRow): UnsafeRow = {
+    if (bytes != null) {
+      // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform.
+      reusedRow.pointTo(
+        bytes,
+        Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
+        bytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
+      reusedRow
+    } else {
+      null
+    }
+  }
+}
+
+class PrefixKeyScanStateEncoder(
+    keySchema: StructType,
+    valueSchema: StructType,
+    numColsPrefixKey: Int) extends RocksDBStateEncoder {
+
+  import RocksDBStateEncoder._
+
+  require(keySchema.length > numColsPrefixKey, "The number of columns in the key must be " +
+    "greater than the number of columns for prefix key!")
+
+  private val prefixKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.take(numColsPrefixKey)
+  }
+
+  private val remainingKeyFieldsWithIdx: Seq[(StructField, Int)] = {
+    keySchema.zipWithIndex.drop(numColsPrefixKey)
+  }
+
+  private val prefixKeyProjection: UnsafeProjection = {
+    val refs = prefixKeyFieldsWithIdx.map(x => BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  private val remainingKeyProjection: UnsafeProjection = {
+    val refs = remainingKeyFieldsWithIdx.map(x =>
+      BoundReference(x._2, x._1.dataType, x._1.nullable))
+    UnsafeProjection.create(refs)
+  }
+
+  // This is quite simple to do - just bind sequentially, as we don't change the order.
+  private val restoreKeyProjection: UnsafeProjection = UnsafeProjection.create(keySchema)
+
+  // Reusable objects
+  private val joinedRowOnKey = new JoinedRow()
+  private val valueRow = new UnsafeRow(valueSchema.size)
+  private val rowTuple = new UnsafeRowPair()
+
+  override def encodeKey(row: UnsafeRow): Array[Byte] = {
+    val prefixKeyEncoded = encodeUnsafeRow(extractPrefixKey(row))
+    val remainingEncoded = encodeUnsafeRow(remainingKeyProjection(row))
+
+    val encodedBytes = new Array[Byte](prefixKeyEncoded.length + remainingEncoded.length + 4)
+    Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncoded.length)
+    Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded.length)
+    // NOTE: We don't put the length of remainingEncoded as we can calculate later
+    // on deserialization.
+    Platform.copyMemory(remainingEncoded, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, Platform.BYTE_ARRAY_OFFSET + 4 + prefixKeyEncoded.length,
+      remainingEncoded.length)
+
+    encodedBytes
+  }
+
+  override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
+
+  override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+    val prefixKeyEncodedLen = Platform.getInt(keyBytes, Platform.BYTE_ARRAY_OFFSET)
+    val prefixKeyEncoded = new Array[Byte](prefixKeyEncodedLen)
+    Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded,
+      Platform.BYTE_ARRAY_OFFSET, prefixKeyEncodedLen)
+
+    // Here we calculate the remainingKeyEncodedLen leveraging the length of keyBytes
+    val remainingKeyEncodedLen = keyBytes.length - 4 - prefixKeyEncodedLen
+
+    val remainingKeyEncoded = new Array[Byte](remainingKeyEncodedLen)
+    Platform.copyMemory(keyBytes, Platform.BYTE_ARRAY_OFFSET + 4 +
+      prefixKeyEncodedLen, remainingKeyEncoded, Platform.BYTE_ARRAY_OFFSET,
+      remainingKeyEncodedLen)
+
+    val prefixKeyDecoded = decodeToUnsafeRow(prefixKeyEncoded, numFields = numColsPrefixKey)
+    val remainingKeyDecoded = decodeToUnsafeRow(remainingKeyEncoded,
+      numFields = keySchema.length - numColsPrefixKey)
+
+    restoreKeyProjection(joinedRowOnKey.withLeft(prefixKeyDecoded).withRight(remainingKeyDecoded))
+  }
+
+  override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
+    decodeToUnsafeRow(valueBytes, valueRow)
+  }
+
+  override def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
+    prefixKeyProjection(key)
+  }
+
+  override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
+    val prefixKeyEncoded = encodeUnsafeRow(prefixKey)
+    val prefix = new Array[Byte](prefixKeyEncoded.length + 4)
+    Platform.putInt(prefix, Platform.BYTE_ARRAY_OFFSET, prefixKeyEncoded.length)
+    Platform.copyMemory(prefixKeyEncoded, Platform.BYTE_ARRAY_OFFSET, prefix,
+      Platform.BYTE_ARRAY_OFFSET + 4, prefixKeyEncoded.length)
+    prefix
+  }
+
+  override def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
+    rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value))
+  }
+
+  override def supportPrefixKeyScan: Boolean = true
+}
+
+/**
+ * Encodes/decodes UnsafeRows to versioned byte arrays.
+ * It uses the first byte of the generated byte array to store the version the describes how the
+ * row is encoded in the rest of the byte array. Currently, the default version is 0,
+ *
+ * VERSION 0:  [ VERSION (1 byte) | ROW (N bytes) ]
+ *    The bytes of a UnsafeRow is written unmodified to starting from offset 1
+ *    (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes,
+ *    then the generated array byte will be N+1 bytes.
+ */
+class NoPrefixKeyStateEncoder(keySchema: StructType, valueSchema: StructType)
+  extends RocksDBStateEncoder {
+
+  import RocksDBStateEncoder._
+
+  // Reusable objects
+  private val keyRow = new UnsafeRow(keySchema.size)
+  private val valueRow = new UnsafeRow(valueSchema.size)
+  private val rowTuple = new UnsafeRowPair()
+
+  override def encodeKey(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
+
+  override def encodeValue(row: UnsafeRow): Array[Byte] = encodeUnsafeRow(row)
+
+  /**
+   * Decode byte array for a key to a UnsafeRow.
+   * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
+   *       the given byte array.
+   */
+  override def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
+    decodeToUnsafeRow(keyBytes, keyRow)
+  }
+
+  /**
+   * Decode byte array for a value to a UnsafeRow.
+   *
+   * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
+   *       the given byte array.
+   */
+  override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
+    decodeToUnsafeRow(valueBytes, valueRow)
+  }
+
+  /**
+   * Decode pair of key-value byte arrays in a pair of key-value UnsafeRows.
+   *
+   * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
+   *       the given byte array.
+   */
+  override def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
+    rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value))
+  }
+
+  override def supportPrefixKeyScan: Boolean = false
+
+  override def extractPrefixKey(key: UnsafeRow): UnsafeRow = {
+    throw new IllegalStateException("This encoder doesn't support prefix key!")
+  }
+
+  override def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte] = {
+    throw new IllegalStateException("This encoder doesn't support prefix key!")
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 3ebaa8c..2e39dea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -25,10 +25,9 @@ import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.unsafe.Platform
 import org.apache.spark.util.Utils
 
-private[state] class RocksDBStateStoreProvider
+private[sql] class RocksDBStateStoreProvider
   extends StateStoreProvider with Logging with Closeable {
   import RocksDBStateStoreProvider._
 
@@ -48,7 +47,7 @@ private[state] class RocksDBStateStoreProvider
 
     override def get(key: UnsafeRow): UnsafeRow = {
       verify(key != null, "Key cannot be null")
-      val value = encoder.decodeValue(rocksDB.get(encoder.encode(key)))
+      val value = encoder.decodeValue(rocksDB.get(encoder.encodeKey(key)))
       if (!isValidated && value != null) {
         StateStoreProvider.validateStateRowFormat(
           key, keySchema, value, valueSchema, storeConf)
@@ -61,20 +60,13 @@ private[state] class RocksDBStateStoreProvider
       verify(state == UPDATING, "Cannot put after already committed or aborted")
       verify(key != null, "Key cannot be null")
       require(value != null, "Cannot put a null value")
-      rocksDB.put(encoder.encode(key), encoder.encode(value))
+      rocksDB.put(encoder.encodeKey(key), encoder.encodeValue(value))
     }
 
     override def remove(key: UnsafeRow): Unit = {
       verify(state == UPDATING, "Cannot remove after already committed or aborted")
       verify(key != null, "Key cannot be null")
-      rocksDB.remove(encoder.encode(key))
-    }
-
-    override def getRange(
-        start: Option[UnsafeRow],
-        end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
-      verify(state == UPDATING, "Cannot call getRange() after already committed or aborted")
-      iterator()
+      rocksDB.remove(encoder.encodeKey(key))
     }
 
     override def iterator(): Iterator[UnsafeRowPair] = {
@@ -89,6 +81,13 @@ private[state] class RocksDBStateStoreProvider
       }
     }
 
+    override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
+      require(encoder.supportPrefixKeyScan, "Prefix scan requires setting prefix key!")
+
+      val prefix = encoder.encodePrefixKey(prefixKey)
+      rocksDB.prefixScan(prefix).map(kv => encoder.decode(kv))
+    }
+
     override def commit(): Long = synchronized {
       verify(state == UPDATING, "Cannot commit after already committed or aborted")
       val newVersion = rocksDB.commit()
@@ -147,7 +146,7 @@ private[state] class RocksDBStateStoreProvider
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
-      indexOrdinal: Option[Int],
+      numColsPrefixKey: Int,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): Unit = {
     this.stateStoreId_ = stateStoreId
@@ -155,6 +154,13 @@ private[state] class RocksDBStateStoreProvider
     this.valueSchema = valueSchema
     this.storeConf = storeConf
     this.hadoopConf = hadoopConf
+
+    require((keySchema.length == 0 && numColsPrefixKey == 0) ||
+      (keySchema.length > numColsPrefixKey), "The number of columns in the key must be " +
+      "greater than the number of columns for prefix key!")
+
+    this.encoder = RocksDBStateEncoder.getEncoder(keySchema, valueSchema, numColsPrefixKey)
+
     rocksDB // lazy initialization
   }
 
@@ -195,93 +201,11 @@ private[state] class RocksDBStateStoreProvider
     new RocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, storeIdStr)
   }
 
-  private lazy val encoder = new StateEncoder
+  @volatile private var encoder: RocksDBStateEncoder = _
 
   private def verify(condition: => Boolean, msg: String): Unit = {
     if (!condition) { throw new IllegalStateException(msg) }
   }
-
-  /**
-   * Encodes/decodes UnsafeRows to versioned byte arrays.
-   * It uses the first byte of the generated byte array to store the version that describes how the
-   * row is encoded in the rest of the byte array. Currently, the default version is 0,
-   *
-   * VERSION 0:  [ VERSION (1 byte) | ROW (N bytes) ]
-   *    The bytes of a UnsafeRow is written unmodified to starting from offset 1
-   *    (offset 0 is the version byte of value 0). That is, if the unsafe row has N bytes,
-   *    then the generated array byte will be N+1 bytes.
-   */
-  class StateEncoder {
-    import RocksDBStateStoreProvider._
-
-    // Reusable objects
-    private val keyRow = new UnsafeRow(keySchema.size)
-    private val valueRow = new UnsafeRow(valueSchema.size)
-    private val rowTuple = new UnsafeRowPair()
-
-    /**
-     * Encode the UnsafeRow of N bytes as a N+1 byte array.
-     * @note This creates a new byte array and memcopies the UnsafeRow to the new array.
-     */
-    def encode(row: UnsafeRow): Array[Byte] = {
-      val bytesToEncode = row.getBytes
-      val encodedBytes = new Array[Byte](bytesToEncode.length + STATE_ENCODING_NUM_VERSION_BYTES)
-      Platform.putByte(encodedBytes, Platform.BYTE_ARRAY_OFFSET, STATE_ENCODING_VERSION)
-      // Platform.BYTE_ARRAY_OFFSET is the recommended way to memcopy b/w byte arrays. See Platform.
-      Platform.copyMemory(
-        bytesToEncode, Platform.BYTE_ARRAY_OFFSET,
-        encodedBytes, Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
-        bytesToEncode.length)
-      encodedBytes
-    }
-
-    /**
-     * Decode byte array for a key to a UnsafeRow.
-     * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
-     *       the given byte array.
-     */
-    def decodeKey(keyBytes: Array[Byte]): UnsafeRow = {
-      if (keyBytes != null) {
-        // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform.
-        keyRow.pointTo(
-          keyBytes,
-          Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
-          keyBytes.length - STATE_ENCODING_NUM_VERSION_BYTES)
-        keyRow
-      } else {
-        null
-      }
-    }
-
-    /**
-     * Decode byte array for a value to a UnsafeRow.
-     *
-     * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
-     *       the given byte array.
-     */
-    def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
-      if (valueBytes != null) {
-        // Platform.BYTE_ARRAY_OFFSET is the recommended way refer to the 1st offset. See Platform.
-        valueRow.pointTo(
-          valueBytes,
-          Platform.BYTE_ARRAY_OFFSET + STATE_ENCODING_NUM_VERSION_BYTES,
-          valueBytes.size - STATE_ENCODING_NUM_VERSION_BYTES)
-        valueRow
-      } else {
-        null
-      }
-    }
-
-    /**
-     * Decode pair of key-value byte arrays in a pair of key-value UnsafeRows.
-     *
-     * @note The UnsafeRow returned is reused across calls, and the UnsafeRow just points to
-     *       the given byte array.
-     */
-    def decode(byteArrayTuple: ByteArrayPair): UnsafeRowPair = {
-      rowTuple.withRows(decodeKey(byteArrayTuple.key), decodeValue(byteArrayTuple.value))
-    }
-  }
 }
 
 object RocksDBStateStoreProvider {
@@ -328,4 +252,3 @@ object RocksDBStateStoreProvider {
     CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED
   )
 }
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 60ad318..5020638 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -44,6 +44,11 @@ import org.apache.spark.util.{ThreadUtils, Utils}
  *
  * `abort` method will be called when the task is completed - please clean up the resources in
  * the method.
+ *
+ * IMPLEMENTATION NOTES:
+ * * The implementation can throw exception on calling prefixScan method if the functionality is
+ *   not supported yet from the implementation. Note that some stateful operations would not work
+ *   on disabling prefixScan functionality.
  */
 trait ReadStateStore {
 
@@ -60,21 +65,17 @@ trait ReadStateStore {
   def get(key: UnsafeRow): UnsafeRow
 
   /**
-   * Get key value pairs with optional approximate `start` and `end` extents.
-   * If the State Store implementation maintains indices for the data based on the optional
-   * `keyIndexOrdinal` over fields `keySchema` (see `StateStoreProvider.init()`), then it can use
-   * `start` and `end` to make a best-effort scan over the data. Default implementation returns
-   * the full data scan iterator, which is correct but inefficient. Custom implementations must
-   * ensure that updates (puts, removes) can be made while iterating over this iterator.
+   * Return an iterator containing all the key-value pairs which are matched with
+   * the given prefix key.
+   *
+   * The operator will provide numColsPrefixKey greater than 0 in StateStoreProvider.init method
+   * if the operator needs to leverage the "prefix scan" feature. The schema of the prefix key
+   * should be same with the leftmost `numColsPrefixKey` columns of the key schema.
    *
-   * @param start UnsafeRow having the `keyIndexOrdinal` column set with appropriate starting value.
-   * @param end UnsafeRow having the `keyIndexOrdinal` column set with appropriate ending value.
-   * @return An iterator of key-value pairs that is guaranteed not miss any key between start and
-   *         end, both inclusive.
+   * It is expected to throw exception if Spark calls this method without setting numColsPrefixKey
+   * to the greater than 0.
    */
-  def getRange(start: Option[UnsafeRow], end: Option[UnsafeRow]): Iterator[UnsafeRowPair] = {
-    iterator()
-  }
+  def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair]
 
   /** Return an iterator containing all the key-value pairs in the StateStore. */
   def iterator(): Iterator[UnsafeRowPair]
@@ -149,6 +150,9 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore {
   override def iterator(): Iterator[UnsafeRowPair] = store.iterator()
 
   override def abort(): Unit = store.abort()
+
+  override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] =
+    store.prefixScan(prefixKey)
 }
 
 /**
@@ -251,8 +255,9 @@ trait StateStoreProvider {
    * @param stateStoreId Id of the versioned StateStores that this provider will generate
    * @param keySchema Schema of keys to be stored
    * @param valueSchema Schema of value to be stored
-   * @param keyIndexOrdinal Optional column (represent as the ordinal of the field in keySchema) by
-   *                        which the StateStore implementation could index the data.
+   * @param numColsPrefixKey The number of leftmost columns to be used as prefix key.
+   *                         A value not greater than 0 means the operator doesn't activate prefix
+   *                         key, and the operator should not call prefixScan method in StateStore.
    * @param storeConfs Configurations used by the StateStores
    * @param hadoopConf Hadoop configuration that could be used by StateStore to save state data
    */
@@ -260,7 +265,7 @@ trait StateStoreProvider {
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
-      keyIndexOrdinal: Option[Int], // for sorting the data by their keys
+      numColsPrefixKey: Int,
       storeConfs: StateStoreConf,
       hadoopConf: Configuration): Unit
 
@@ -313,11 +318,12 @@ object StateStoreProvider {
       providerId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
-      indexOrdinal: Option[Int], // for sorting the data
+      numColsPrefixKey: Int,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStoreProvider = {
     val provider = create(storeConf.providerClass)
-    provider.init(providerId.storeId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
+    provider.init(providerId.storeId, keySchema, valueSchema, numColsPrefixKey,
+      storeConf, hadoopConf)
     provider
   }
 
@@ -465,13 +471,13 @@ object StateStore extends Logging {
       storeProviderId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
-      indexOrdinal: Option[Int],
+      numColsPrefixKey: Int,
       version: Long,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): ReadStateStore = {
     require(version >= 0)
     val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
-      indexOrdinal, storeConf, hadoopConf)
+      numColsPrefixKey, storeConf, hadoopConf)
     storeProvider.getReadStore(version)
   }
 
@@ -480,13 +486,13 @@ object StateStore extends Logging {
       storeProviderId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
-      indexOrdinal: Option[Int],
+      numColsPrefixKey: Int,
       version: Long,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStore = {
     require(version >= 0)
     val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
-      indexOrdinal, storeConf, hadoopConf)
+      numColsPrefixKey, storeConf, hadoopConf)
     storeProvider.getStore(version)
   }
 
@@ -494,7 +500,7 @@ object StateStore extends Logging {
       storeProviderId: StateStoreProviderId,
       keySchema: StructType,
       valueSchema: StructType,
-      indexOrdinal: Option[Int],
+      numColsPrefixKey: Int,
       storeConf: StateStoreConf,
       hadoopConf: Configuration): StateStoreProvider = {
     loadedProviders.synchronized {
@@ -521,7 +527,7 @@ object StateStore extends Logging {
       val provider = loadedProviders.getOrElseUpdate(
         storeProviderId,
         StateStoreProvider.createAndInit(
-          storeProviderId, keySchema, valueSchema, indexOrdinal, storeConf, hadoopConf)
+          storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeConf, hadoopConf)
       )
       val otherProviderIds = loadedProviders.keys.filter(_ != storeProviderId).toSeq
       val providerIdsToUnload = reportActiveStoreInstance(storeProviderId, otherProviderIds)
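
As an illustration of the caller-facing side of this change (not part of the patch itself): an operator whose key schema is [ sessionId | sessionStart ] and which initializes its store with numColsPrefixKey = 1 could build a one-column prefix key row and scan the matching state entries roughly like the sketch below. The schema, the object name and the rowsForSession helper are assumptions made for the example; only StateStore.prefixScan comes from this commit.

```scala
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.state.{StateStore, UnsafeRowPair}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Hypothetical operator-side helper: the store is assumed to have been initialized with
// a key schema of [ sessionId | sessionStart ] and numColsPrefixKey = 1, so the prefix
// key row carries only the leftmost column.
object PrefixScanUsageSketch {
  private val prefixKeySchema: StructType = new StructType().add("sessionId", StringType)
  private val prefixKeyProjection = UnsafeProjection.create(prefixKeySchema)

  // Returns every key-value pair in the store whose key starts with the given session id.
  def rowsForSession(store: StateStore, sessionId: String): Iterator[UnsafeRowPair] = {
    val prefixKey: UnsafeRow =
      prefixKeyProjection(new GenericInternalRow(Array[Any](UTF8String.fromString(sessionId))))
    store.prefixScan(prefixKey)
  }
}
```
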
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index f21e2ff..fbe83ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -74,7 +74,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
     storeVersion: Long,
     keySchema: StructType,
     valueSchema: StructType,
-    indexOrdinal: Option[Int],
+    numColsPrefixKey: Int,
     sessionState: SessionState,
     @transient private val storeCoordinator: Option[StateStoreCoordinatorRef],
     extraOptions: Map[String, String] = Map.empty)
@@ -87,7 +87,7 @@ class ReadStateStoreRDD[T: ClassTag, U: ClassTag](
     val storeProviderId = getStateProviderId(partition)
 
     val store = StateStore.getReadOnly(
-      storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion,
+      storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeVersion,
       storeConf, hadoopConfBroadcast.value.value)
     val inputIter = dataRDD.iterator(partition, ctxt)
     storeReadFunction(store, inputIter)
@@ -108,7 +108,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     storeVersion: Long,
     keySchema: StructType,
     valueSchema: StructType,
-    indexOrdinal: Option[Int],
+    numColsPrefixKey: Int,
     sessionState: SessionState,
     @transient private val storeCoordinator: Option[StateStoreCoordinatorRef],
     extraOptions: Map[String, String] = Map.empty)
@@ -121,7 +121,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     val storeProviderId = getStateProviderId(partition)
 
     val store = StateStore.get(
-      storeProviderId, keySchema, valueSchema, indexOrdinal, storeVersion,
+      storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeVersion,
       storeConf, hadoopConfBroadcast.value.value)
     val inputIter = dataRDD.iterator(partition, ctxt)
     storeUpdateFunction(store, inputIter)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
index 0496e47..36138f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManager.scala
@@ -92,7 +92,7 @@ abstract class StreamingAggregationStateManagerBaseImpl(
 
   override def keys(store: ReadStateStore): Iterator[UnsafeRow] = {
     // discard and don't convert values to avoid computation
-    store.getRange(None, None).map(_.key)
+    store.iterator().map(_.key)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
index d342c83..f301d23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -364,7 +364,7 @@ class SymmetricHashJoinStateManager(
       val storeProviderId = StateStoreProviderId(
         stateInfo.get, partitionId, getStateStoreName(joinSide, stateStoreType))
       val store = StateStore.get(
-        storeProviderId, keySchema, valueSchema, None,
+        storeProviderId, keySchema, valueSchema, numColsPrefixKey = 0,
         stateInfo.get.storeVersion, storeConf, hadoopConf)
       logInfo(s"Loaded store ${store.id}")
       store
@@ -410,7 +410,7 @@ class SymmetricHashJoinStateManager(
 
     def iterator: Iterator[KeyAndNumValues] = {
       val keyAndNumValues = new KeyAndNumValues()
-      stateStore.getRange(None, None).map { case pair =>
+      stateStore.iterator().map { pair =>
         keyAndNumValues.withNew(pair.key, pair.value.getLong(0))
       }
     }
@@ -605,7 +605,7 @@ class SymmetricHashJoinStateManager(
 
     def iterator: Iterator[KeyWithIndexAndValue] = {
       val keyWithIndexAndValue = new KeyWithIndexAndValue()
-      stateStore.getRange(None, None).map { pair =>
+      stateStore.iterator().map { pair =>
         val valuePair = valueRowConverter.convertValue(pair.value)
         keyWithIndexAndValue.withNew(
           keyRowGenerator(pair.key), pair.key.getLong(indexOrdinalInKeyWithIndexRow), valuePair)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index fa89c50..01ff72b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -35,14 +35,14 @@ package object state {
         stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,
-        indexOrdinal: Option[Int])(
+        numColsPrefixKey: Int)(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
 
       mapPartitionsWithStateStore(
         stateInfo,
         keySchema,
         valueSchema,
-        indexOrdinal,
+        numColsPrefixKey,
         sqlContext.sessionState,
         Some(sqlContext.streams.stateStoreCoordinator))(
         storeUpdateFunction)
@@ -53,7 +53,7 @@ package object state {
         stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,
-        indexOrdinal: Option[Int],
+        numColsPrefixKey: Int,
         sessionState: SessionState,
         storeCoordinator: Option[StateStoreCoordinatorRef],
         extraOptions: Map[String, String] = Map.empty)(
@@ -77,7 +77,7 @@ package object state {
         stateInfo.storeVersion,
         keySchema,
         valueSchema,
-        indexOrdinal,
+        numColsPrefixKey,
         sessionState,
         storeCoordinator,
         extraOptions)
@@ -88,7 +88,7 @@ package object state {
         stateInfo: StatefulOperatorStateInfo,
         keySchema: StructType,
         valueSchema: StructType,
-        indexOrdinal: Option[Int],
+        numColsPrefixKey: Int,
         sessionState: SessionState,
         storeCoordinator: Option[StateStoreCoordinatorRef],
         extraOptions: Map[String, String] = Map.empty)(
@@ -112,7 +112,7 @@ package object state {
         stateInfo.storeVersion,
         keySchema,
         valueSchema,
-        indexOrdinal,
+        numColsPrefixKey,
         sessionState,
         storeCoordinator,
         extraOptions)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 22feff3..3f6a7ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -236,7 +236,7 @@ trait WatermarkSupport extends SparkPlan {
   protected def removeKeysOlderThanWatermark(store: StateStore): Unit = {
     if (watermarkPredicateForKeys.nonEmpty) {
       val numRemovedStateRows = longMetric("numRemovedStateRows")
-      store.getRange(None, None).foreach { rowPair =>
+      store.iterator().foreach { rowPair =>
         if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
           store.remove(rowPair.key)
           numRemovedStateRows += 1
@@ -306,7 +306,7 @@ case class StateStoreRestoreExec(
       getStateInfo,
       keyExpressions.toStructType,
       stateManager.getStateValueSchema,
-      indexOrdinal = None,
+      numColsPrefixKey = 0,
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
         val hasInput = iter.hasNext
@@ -368,7 +368,7 @@ case class StateStoreSaveExec(
       getStateInfo,
       keyExpressions.toStructType,
       stateManager.getStateValueSchema,
-      indexOrdinal = None,
+      numColsPrefixKey = 0,
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { (store, iter) =>
         val numOutputRows = longMetric("numOutputRows")
@@ -530,7 +530,7 @@ case class StreamingDeduplicateExec(
       getStateInfo,
       keyExpressions.toStructType,
       child.output.toStructType,
-      indexOrdinal = None,
+      numColsPrefixKey = 0,
       session.sessionState,
       Some(session.streams.stateStoreCoordinator),
       // We won't check value row in state store since the value StreamingDeduplicateExec.EMPTY_ROW
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
index 0e9d12d..8bba9b8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/streamingLimits.scala
@@ -52,7 +52,7 @@ case class StreamingGlobalLimitExec(
         getStateInfo,
         keySchema,
         valueSchema,
-        indexOrdinal = None,
+        numColsPrefixKey = 0,
         session.sessionState,
         Some(session.streams.stateStoreCoordinator)) { (store, iter) =>
       val key = UnsafeProjection.create(keySchema)(new GenericInternalRow(Array[Any](null)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
index 98586d6..e52ccd0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
@@ -46,4 +46,8 @@ class MemoryStateStore extends StateStore() {
   override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty)
 
   override def hasCommitted: Boolean = true
+
+  override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
+    throw new UnsupportedOperationException("Doesn't support prefix scan!")
+  }
 }
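
MemoryStateStore above simply opts out of prefix scan. For a store that does want to honor it, the contract is equality matching on the projected leftmost key columns (no range semantics). A rough, provider-agnostic sketch of that matching, assuming the key schema and numColsPrefixKey are available and ignoring the indexed layouts a real provider would use:

```scala
import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.state.UnsafeRowPair
import org.apache.spark.sql.types.StructType

object NaivePrefixScanSketch {
  // Projection extracting the leftmost `numColsPrefixKey` columns out of a full key row.
  def prefixKeyProjection(keySchema: StructType, numColsPrefixKey: Int): UnsafeProjection = {
    val refs = keySchema.take(numColsPrefixKey).zipWithIndex.map {
      case (field, ordinal) => BoundReference(ordinal, field.dataType, field.nullable)
    }
    UnsafeProjection.create(refs)
  }

  // Equality-only matching: keep the pairs whose projected leftmost columns equal the
  // given prefix key. A real provider would index by prefix instead of filtering.
  def naivePrefixScan(
      entries: Iterator[UnsafeRowPair],
      prefixKey: UnsafeRow,
      projectPrefix: UnsafeProjection): Iterator[UnsafeRowPair] = {
    entries.filter(pair => projectPrefix(pair.key) == prefixKey)
  }
}
```
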
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index b9cc844..b91ed26 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -43,8 +43,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
 
     val provider = newStoreProvider()
     val store = provider.getStore(0)
-    val keyRow = stringToRow("a")
-    val valueRow = intToRow(1)
+    val keyRow = dataToKeyRow("a", 0)
+    val valueRow = dataToValueRow(1)
     store.put(keyRow, valueRow)
     val iter = provider.rocksDB.iterator()
     assert(iter.hasNext)
@@ -76,7 +76,7 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
 
       // Create state store in a task and get the RocksDBConf from the instantiated RocksDB instance
       val rocksDBConfInTask: RocksDBConf = testRDD.mapPartitionsWithStateStore[RocksDBConf](
-        spark.sqlContext, testStateInfo, testSchema, testSchema, None) {
+        spark.sqlContext, testStateInfo, testSchema, testSchema, 0) {
           (store: StateStore, _: Iterator[String]) =>
             // Use reflection to get RocksDB instance
             val dbInstanceMethod =
@@ -101,8 +101,8 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     val provider = newStoreProvider()
     val store = provider.getStore(0)
     // Verify state after updating
-    put(store, "a", 1)
-    assert(get(store, "a") === Some(1))
+    put(store, "a", 0, 1)
+    assert(get(store, "a", 0) === Some(1))
     assert(store.commit() === 1)
     assert(store.hasCommitted)
     val storeMetrics = store.metrics
@@ -118,29 +118,36 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
   }
 
   def newStoreProvider(storeId: StateStoreId): RocksDBStateStoreProvider = {
-    val keySchema = StructType(Seq(StructField("key", StringType, true)))
-    val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+    newStoreProvider(storeId, numColsPrefixKey = 0)
+  }
+
+  override def newStoreProvider(numPrefixCols: Int): RocksDBStateStoreProvider = {
+    newStoreProvider(StateStoreId(newDir(), Random.nextInt(), 0), numColsPrefixKey = numPrefixCols)
+  }
+
+  def newStoreProvider(
+      storeId: StateStoreId,
+      numColsPrefixKey: Int): RocksDBStateStoreProvider = {
     val provider = new RocksDBStateStoreProvider()
     provider.init(
-      storeId, keySchema, valueSchema, indexOrdinal = None, new StateStoreConf, new Configuration)
+      storeId, keySchema, valueSchema, numColsPrefixKey = numColsPrefixKey,
+      new StateStoreConf, new Configuration)
     provider
   }
 
-  override def getLatestData(storeProvider: RocksDBStateStoreProvider): Set[(String, Int)] = {
+  override def getLatestData(
+      storeProvider: RocksDBStateStoreProvider): Set[((String, Int), Int)] = {
     getData(storeProvider, version = -1)
   }
 
   override def getData(
       provider: RocksDBStateStoreProvider,
-      version: Int = -1): Set[(String, Int)] = {
+      version: Int = -1): Set[((String, Int), Int)] = {
     val reloadedProvider = newStoreProvider(provider.stateStoreId)
     val versionToRead = if (version < 0) reloadedProvider.latestVersion else version
-    reloadedProvider.getStore(versionToRead).iterator().map(rowsToStringInt).toSet
+    reloadedProvider.getStore(versionToRead).iterator().map(rowPairToDataPair).toSet
   }
 
-  override protected val keySchema = StructType(Seq(StructField("key", StringType, true)))
-  override protected val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
-
   override def newStoreProvider(
     minDeltasForSnapshot: Int,
     numOfVersToRetainInMemory: Int): RocksDBStateStoreProvider = newStoreProvider()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
index 378aa1d..6bb8ebe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.LocalSparkSession._
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.util.quietly
 import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.util.{CompletionIterator, Utils}
 
 class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll {
@@ -41,8 +40,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
 
   private val sparkConf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
   private val tempDir = Files.createTempDirectory("StateStoreRDDSuite").toString
-  private val keySchema = StructType(Seq(StructField("key", StringType, true)))
-  private val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+  private val keySchema = StateStoreTestsHelper.keySchema
+  private val valueSchema = StateStoreTestsHelper.valueSchema
 
   after {
     StateStore.stop()
@@ -59,19 +58,19 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
   test("versioning and immutability") {
     withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
       val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString
-      val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-            spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(
-            increment)
-      assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+      val rdd1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
+        .mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 0),
+          keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+      assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1))
 
       // Generate next version of stores
-      val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
-        spark.sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)(
-        increment)
-      assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+      val rdd2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("c", 0)))
+        .mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 1),
+          keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+      assert(rdd2.collect().toSet === Set(("a", 0) -> 3, ("b", 0) -> 1, ("c", 0) -> 1))
 
       // Make sure the previous RDD still has the same data.
-      assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+      assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1))
     }
   }
 
@@ -80,24 +79,24 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
 
     def makeStoreRDD(
         spark: SparkSession,
-        seq: Seq[String],
-        storeVersion: Int): RDD[(String, Int)] = {
+        seq: Seq[(String, Int)],
+        storeVersion: Int): RDD[((String, Int), Int)] = {
       implicit val sqlContext = spark.sqlContext
-      makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore(
+      makeRDD(spark.sparkContext, Seq(("a", 0))).mapPartitionsWithStateStore(
         sqlContext, operatorStateInfo(path, version = storeVersion),
-        keySchema, valueSchema, None)(increment)
+        keySchema, valueSchema, numColsPrefixKey = 0)(increment)
     }
 
     // Generate RDDs and state store data
     withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
       for (i <- 1 to 20) {
-        require(makeStoreRDD(spark, Seq("a"), i - 1).collect().toSet === Set("a" -> i))
+        require(makeStoreRDD(spark, Seq(("a", 0)), i - 1).collect().toSet === Set(("a", 0) -> i))
       }
     }
 
     // With a new context, try using the earlier state store data
     withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark =>
-      assert(makeStoreRDD(spark, Seq("a"), 20).collect().toSet === Set("a" -> 21))
+      assert(makeStoreRDD(spark, Seq(("a", 0)), 20).collect().toSet === Set(("a", 0) -> 21))
     }
   }
 
@@ -108,43 +107,48 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
       val opId = 0
 
       // Returns an iterator of the incremented value made into the store
-      def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = {
-        val resIterator = iter.map { s =>
-          val key = stringToRow(s)
-          val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0)
+      def iteratorOfPuts(
+          store: StateStore,
+          iter: Iterator[(String, Int)]): Iterator[((String, Int), Int)] = {
+        val resIterator = iter.map { case (s, i) =>
+          val key = dataToKeyRow(s, i)
+          val oldValue = Option(store.get(key)).map(valueRowToData).getOrElse(0)
           val newValue = oldValue + 1
-          store.put(key, intToRow(newValue))
-          (s, newValue)
+          store.put(key, dataToValueRow(newValue))
+          ((s, i), newValue)
         }
-        CompletionIterator[(String, Int), Iterator[(String, Int)]](resIterator, {
+        CompletionIterator[((String, Int), Int), Iterator[((String, Int), Int)]](resIterator, {
           store.commit()
         })
       }
 
       def iteratorOfGets(
           store: StateStore,
-          iter: Iterator[String]): Iterator[(String, Option[Int])] = {
-        iter.map { s =>
-          val key = stringToRow(s)
-          val value = Option(store.get(key)).map(rowToInt)
-          (s, value)
+          iter: Iterator[(String, Int)]): Iterator[((String, Int), Option[Int])] = {
+        iter.map { case (s, i) =>
+          val key = dataToKeyRow(s, i)
+          val value = Option(store.get(key)).map(valueRowToData)
+          ((s, i), value)
         }
       }
 
-      val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
-        spark.sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(
-        iteratorOfGets)
-      assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None))
-
-      val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-        sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(
-        iteratorOfPuts)
-      assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1))
-
-      val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore(
-        sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)(
-        iteratorOfGets)
-      assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None))
+      val rddOfGets1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0)))
+        .mapPartitionsWithStateStore(spark.sqlContext, operatorStateInfo(path, version = 0),
+          keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfGets)
+      assert(rddOfGets1.collect().toSet ===
+        Set(("a", 0) -> None, ("b", 0) -> None, ("c", 0) -> None))
+
+      val rddOfPuts = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
+        .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 0),
+          keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfPuts)
+      assert(rddOfPuts.collect().toSet ===
+        Set(("a", 0) -> 1, ("a", 0) -> 2, ("b", 0) -> 1))
+
+      val rddOfGets2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("c", 0)))
+        .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 1),
+          keySchema, valueSchema, numColsPrefixKey = 0)(iteratorOfGets)
+      assert(rddOfGets2.collect().toSet ===
+        Set(("a", 0) -> Some(2), ("b", 0) -> Some(1), ("c", 0) -> None))
     }
   }
 
@@ -166,9 +170,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
           coordinatorRef.getLocation(storeProviderId1) ===
             Some(ExecutorCacheTaskLocation("host1", "exec1").toString))
 
-        val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-          sqlContext, operatorStateInfo(path, queryRunId = queryRunId),
-          keySchema, valueSchema, None)(increment)
+        val rdd = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
+          .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, queryRunId = queryRunId),
+          keySchema, valueSchema, numColsPrefixKey = 0)(increment)
         require(rdd.partitions.length === 2)
 
         assert(
@@ -194,22 +198,24 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
         implicit val sqlContext = spark.sqlContext
         val path = Utils.createDirectory(tempDir, Random.nextFloat.toString).toString
         val opId = 0
-        val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore(
-          sqlContext, operatorStateInfo(path, version = 0), keySchema, valueSchema, None)(increment)
-        assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+        val rdd1 = makeRDD(spark.sparkContext, Seq(("a", 0), ("b", 0), ("a", 0)))
+          .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 0),
+            keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+        assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1))
 
         // Generate next version of stores
-        val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore(
-          sqlContext, operatorStateInfo(path, version = 1), keySchema, valueSchema, None)(increment)
-        assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1))
+        val rdd2 = makeRDD(spark.sparkContext, Seq(("a", 0), ("c", 0)))
+          .mapPartitionsWithStateStore(sqlContext, operatorStateInfo(path, version = 1),
+            keySchema, valueSchema, numColsPrefixKey = 0)(increment)
+        assert(rdd2.collect().toSet === Set(("a", 0) -> 3, ("b", 0) -> 1, ("c", 0) -> 1))
 
         // Make sure the previous RDD still has the same data.
-        assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1))
+        assert(rdd1.collect().toSet === Set(("a", 0) -> 2, ("b", 0) -> 1))
       }
     }
   }
 
-  private def makeRDD(sc: SparkContext, seq: Seq[String]): RDD[String] = {
+  private def makeRDD(sc: SparkContext, seq: Seq[(String, Int)]): RDD[(String, Int)] = {
     sc.makeRDD(seq, 2).groupBy(x => x).flatMap(_._2)
   }
 
@@ -220,13 +226,13 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn
     StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version, numPartitions = 5)
   }
 
-  private val increment = (store: StateStore, iter: Iterator[String]) => {
-    iter.foreach { s =>
-      val key = stringToRow(s)
-      val oldValue = Option(store.get(key)).map(rowToInt).getOrElse(0)
-      store.put(key, intToRow(oldValue + 1))
+  private val increment = (store: StateStore, iter: Iterator[(String, Int)]) => {
+    iter.foreach { case (s, i) =>
+      val key = dataToKeyRow(s, i)
+      val oldValue = Option(store.get(key)).map(valueRowToData).getOrElse(0)
+      store.put(key, dataToValueRow(oldValue + 1))
     }
     store.commit()
-    store.iterator().map(rowsToStringInt)
+    store.iterator().map(rowPairToDataPair)
   }
 }
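
The rewritten tests above rely on helpers such as dataToKeyRow, dataToPrefixKeyRow, dataToValueRow, keyRowToData, valueRowToData and rowPairToDataPair from StateStoreTestsHelper; the updated helper itself is not shown in this excerpt. Assuming the new test key is modeled as (String, Int) with an Int value, plausible definitions could look like the following sketch (illustrative only, not the actual helper):

```scala
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.state.UnsafeRowPair
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

object StateStoreTestsHelperSketch {
  val keySchema: StructType = new StructType().add("key1", StringType).add("key2", IntegerType)
  val valueSchema: StructType = new StructType().add("value", IntegerType)

  private val keyProj = UnsafeProjection.create(keySchema)
  private val prefixKeyProj = UnsafeProjection.create(new StructType().add("key1", StringType))
  private val valueProj = UnsafeProjection.create(valueSchema)

  def dataToKeyRow(s: String, i: Int): UnsafeRow =
    keyProj(new GenericInternalRow(Array[Any](UTF8String.fromString(s), i))).copy()

  def dataToPrefixKeyRow(s: String): UnsafeRow =
    prefixKeyProj(new GenericInternalRow(Array[Any](UTF8String.fromString(s)))).copy()

  def dataToValueRow(i: Int): UnsafeRow =
    valueProj(new GenericInternalRow(Array[Any](i))).copy()

  def keyRowToData(row: UnsafeRow): (String, Int) = (row.getString(0), row.getInt(1))

  def valueRowToData(row: UnsafeRow): Int = row.getInt(0)

  def rowPairToDataPair(pair: UnsafeRowPair): ((String, Int), Int) =
    (keyRowToData(pair.key), valueRowToData(pair.value))

  def rowPairsToDataSet(iter: Iterator[UnsafeRowPair]): Set[((String, Int), Int)] =
    iter.map(rowPairToDataPair).toSet
}
```
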
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 2990860..05772cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -51,9 +51,6 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   import StateStoreTestsHelper._
   import StateStoreCoordinatorSuite._
 
-  override val keySchema = StructType(Seq(StructField("key", StringType, true)))
-  override val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
-
   before {
     StateStore.stop()
     require(!StateStore.isMaintenanceRunning)
@@ -71,27 +68,27 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
     // commit the ver 1 : cache will have one element
     currentVersion = incrementVersion(provider, currentVersion)
-    assert(getLatestData(provider) === Set("a" -> 1))
+    assert(getLatestData(provider) === Set(("a", 0) -> 1))
     var loadedMaps = provider.getLoadedMaps()
     checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1)
-    checkVersion(loadedMaps, 1, Map("a" -> 1))
+    checkVersion(loadedMaps, 1, Map(("a", 0) -> 1))
 
     // commit the ver 2 : cache will have two elements
     currentVersion = incrementVersion(provider, currentVersion)
-    assert(getLatestData(provider) === Set("a" -> 2))
+    assert(getLatestData(provider) === Set(("a", 0) -> 2))
     loadedMaps = provider.getLoadedMaps()
     checkLoadedVersions(loadedMaps, count = 2, earliestKey = 2, latestKey = 1)
-    checkVersion(loadedMaps, 2, Map("a" -> 2))
-    checkVersion(loadedMaps, 1, Map("a" -> 1))
+    checkVersion(loadedMaps, 2, Map(("a", 0) -> 2))
+    checkVersion(loadedMaps, 1, Map(("a", 0) -> 1))
 
     // commit the ver 3 : cache has already two elements and adding ver 3 incurs exceeding cache,
     // and ver 3 will be added but ver 1 will be evicted
     currentVersion = incrementVersion(provider, currentVersion)
-    assert(getLatestData(provider) === Set("a" -> 3))
+    assert(getLatestData(provider) === Set(("a", 0) -> 3))
     loadedMaps = provider.getLoadedMaps()
     checkLoadedVersions(loadedMaps, count = 2, earliestKey = 3, latestKey = 2)
-    checkVersion(loadedMaps, 3, Map("a" -> 3))
-    checkVersion(loadedMaps, 2, Map("a" -> 2))
+    checkVersion(loadedMaps, 3, Map(("a", 0) -> 3))
+    checkVersion(loadedMaps, 2, Map(("a", 0) -> 2))
   }
 
   test("failure after committing with MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 1") {
@@ -102,20 +99,20 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
     // commit the ver 1 : cache will have one element
     currentVersion = incrementVersion(provider, currentVersion)
-    assert(getLatestData(provider) === Set("a" -> 1))
+    assert(getLatestData(provider) === Set(("a", 0) -> 1))
     var loadedMaps = provider.getLoadedMaps()
     checkLoadedVersions(loadedMaps, count = 1, earliestKey = 1, latestKey = 1)
-    checkVersion(loadedMaps, 1, Map("a" -> 1))
+    checkVersion(loadedMaps, 1, Map(("a", 0) -> 1))
 
     // commit the ver 2 : cache has already one elements and adding ver 2 incurs exceeding cache,
     // and ver 2 will be added but ver 1 will be evicted
     // this fact ensures cache miss will occur when this partition succeeds commit
     // but there's a failure afterwards so have to reprocess previous batch
     currentVersion = incrementVersion(provider, currentVersion)
-    assert(getLatestData(provider) === Set("a" -> 2))
+    assert(getLatestData(provider) === Set(("a", 0) -> 2))
     loadedMaps = provider.getLoadedMaps()
     checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2)
-    checkVersion(loadedMaps, 2, Map("a" -> 2))
+    checkVersion(loadedMaps, 2, Map(("a", 0) -> 2))
 
     // suppose there has been failure after committing, and it decided to reprocess previous batch
     currentVersion = 1
@@ -123,15 +120,15 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     // committing to existing version which is committed partially but abandoned globally
     val store = provider.getStore(currentVersion)
     // negative value to represent reprocessing
-    put(store, "a", -2)
+    put(store, "a", 0, -2)
     store.commit()
     currentVersion += 1
 
     // make sure newly committed version is reflected to the cache (overwritten)
-    assert(getLatestData(provider) === Set("a" -> -2))
+    assert(getLatestData(provider) === Set(("a", 0) -> -2))
     loadedMaps = provider.getLoadedMaps()
     checkLoadedVersions(loadedMaps, count = 1, earliestKey = 2, latestKey = 2)
-    checkVersion(loadedMaps, 2, Map("a" -> -2))
+    checkVersion(loadedMaps, 2, Map(("a", 0) -> -2))
   }
 
   test("no cache data with MAX_BATCHES_TO_RETAIN_IN_MEMORY set to 0") {
@@ -142,13 +139,13 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
     // commit the ver 1 : never cached
     currentVersion = incrementVersion(provider, currentVersion)
-    assert(getLatestData(provider) === Set("a" -> 1))
+    assert(getLatestData(provider) === Set(("a", 0) -> 1))
     var loadedMaps = provider.getLoadedMaps()
     assert(loadedMaps.size() === 0)
 
     // commit the ver 2 : never cached
     currentVersion = incrementVersion(provider, currentVersion)
-    assert(getLatestData(provider) === Set("a" -> 2))
+    assert(getLatestData(provider) === Set(("a", 0) -> 2))
     loadedMaps = provider.getLoadedMaps()
     assert(loadedMaps.size() === 0)
   }
@@ -158,19 +155,19 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
     for (i <- 1 to 20) {
       val store = provider.getStore(i - 1)
-      put(store, "a", i)
+      put(store, "a", 0, i)
       store.commit()
       provider.doMaintenance() // do cleanup
     }
     require(
-      rowsToSet(provider.latestIterator()) === Set("a" -> 20),
+      rowPairsToDataSet(provider.latestIterator()) === Set(("a", 0) -> 20),
       "store not updated correctly")
 
     assert(!fileExists(provider, version = 1, isSnapshot = false)) // first file should be deleted
 
     // last couple of versions should be retrievable
-    assert(getData(provider, 20) === Set("a" -> 20))
-    assert(getData(provider, 19) === Set("a" -> 19))
+    assert(getData(provider, 20) === Set(("a", 0) -> 20))
+    assert(getData(provider, 19) === Set(("a", 0) -> 19))
   }
 
   testQuietly("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") {
@@ -192,7 +189,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     val provider = newStoreProvider(opId = Random.nextInt, partition = 0, minDeltasForSnapshot = 5)
     for (i <- 1 to 6) {
       val store = provider.getStore(i - 1)
-      put(store, "a", i)
+      put(store, "a", 0, i)
       store.commit()
       provider.doMaintenance() // do cleanup
     }
@@ -200,14 +197,14 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
       fileExists(provider, version, isSnapshot = true)).getOrElse(fail("snapshot file not found"))
 
     // Corrupt snapshot file and verify that it throws error
-    assert(getData(provider, snapshotVersion) === Set("a" -> snapshotVersion))
+    assert(getData(provider, snapshotVersion) === Set(("a", 0) -> snapshotVersion))
     corruptFile(provider, snapshotVersion, isSnapshot = true)
     intercept[Exception] {
       getData(provider, snapshotVersion)
     }
 
     // Corrupt delta file and verify that it throws error
-    assert(getData(provider, snapshotVersion - 1) === Set("a" -> (snapshotVersion - 1)))
+    assert(getData(provider, snapshotVersion - 1) === Set(("a", 0) -> (snapshotVersion - 1)))
     corruptFile(provider, snapshotVersion - 1, isSnapshot = false)
     intercept[Exception] {
       getData(provider, snapshotVersion - 1)
@@ -231,7 +228,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     val store = provider.getStore(0)
     val noDataMemoryUsed = getSizeOfStateForCurrentVersion(store.metrics)
 
-    put(store, "a", 1)
+    put(store, "a", 0, 1)
     store.commit()
     assert(getSizeOfStateForCurrentVersion(store.metrics) > noDataMemoryUsed)
   }
@@ -261,9 +258,9 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
     def generateStoreVersions(): Unit = {
       for (i <- 1 to 20) {
-        val store = StateStore.get(storeProviderId1, keySchema, valueSchema, None,
+        val store = StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0,
           latestStoreVersion, storeConf, hadoopConf)
-        put(store, "a", i)
+        put(store, "a", 0, i)
         store.commit()
         latestStoreVersion += 1
       }
@@ -311,7 +308,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
           }
 
           // Reload the store and verify
-          StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
+          StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0,
             latestStoreVersion, storeConf, hadoopConf)
           assert(StateStore.isLoaded(storeProviderId1))
 
@@ -323,7 +320,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
           }
 
           // Reload the store and verify
-          StateStore.get(storeProviderId1, keySchema, valueSchema, indexOrdinal = None,
+          StateStore.get(storeProviderId1, keySchema, valueSchema, numColsPrefixKey = 0,
             latestStoreVersion, storeConf, hadoopConf)
           assert(StateStore.isLoaded(storeProviderId1))
 
@@ -331,7 +328,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
           // then this executor should unload inactive instances immediately.
           coordinatorRef
             .reportActiveInstance(storeProviderId1, "other-host", "other-exec", Seq.empty)
-          StateStore.get(storeProviderId2, keySchema, valueSchema, indexOrdinal = None,
+          StateStore.get(storeProviderId2, keySchema, valueSchema, numColsPrefixKey = 0,
             0, storeConf, hadoopConf)
           assert(!StateStore.isLoaded(storeProviderId1))
           assert(StateStore.isLoaded(storeProviderId2))
@@ -354,9 +351,9 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     var currentVersion = 0
 
     currentVersion = updateVersionTo(provider, currentVersion, 2)
-    require(getLatestData(provider) === Set("a" -> 2))
+    require(getLatestData(provider) === Set(("a", 0) -> 2))
     provider.doMaintenance()               // should not generate snapshot files
-    assert(getLatestData(provider) === Set("a" -> 2))
+    assert(getLatestData(provider) === Set(("a", 0) -> 2))
 
     for (i <- 1 to currentVersion) {
       assert(fileExists(provider, i, isSnapshot = false))  // all delta files present
@@ -365,22 +362,22 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
 
     // After version 6, snapshotting should generate one snapshot file
     currentVersion = updateVersionTo(provider, currentVersion, 6)
-    require(getLatestData(provider) === Set("a" -> 6), "store not updated correctly")
+    require(getLatestData(provider) === Set(("a", 0) -> 6), "store not updated correctly")
     provider.doMaintenance()       // should generate snapshot files
 
     val snapshotVersion = (0 to 6).find(version => fileExists(provider, version, isSnapshot = true))
     assert(snapshotVersion.nonEmpty, "snapshot file not generated")
     deleteFilesEarlierThanVersion(provider, snapshotVersion.get)
     assert(
-      getData(provider, snapshotVersion.get) === Set("a" -> snapshotVersion.get),
+      getData(provider, snapshotVersion.get) === Set(("a", 0) -> snapshotVersion.get),
       "snapshotting messed up the data of the snapshotted version")
     assert(
-      getLatestData(provider) === Set("a" -> 6),
+      getLatestData(provider) === Set(("a", 0) -> 6),
       "snapshotting messed up the data of the final version")
 
     // After version 20, snapshotting should generate newer snapshot files
     currentVersion = updateVersionTo(provider, currentVersion, 20)
-    require(getLatestData(provider) === Set("a" -> 20), "store not updated correctly")
+    require(getLatestData(provider) === Set(("a", 0) -> 20), "store not updated correctly")
     provider.doMaintenance()       // do snapshot
 
     val latestSnapshotVersion = (0 to 20).filter(version =>
@@ -389,7 +386,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     assert(latestSnapshotVersion.get > snapshotVersion.get, "newer snapshot not generated")
 
     deleteFilesEarlierThanVersion(provider, latestSnapshotVersion.get)
-    assert(getLatestData(provider) === Set("a" -> 20), "snapshotting messed up the data")
+    assert(getLatestData(provider) === Set(("a", 0) -> 20), "snapshotting messed up the data")
   }
 
   testQuietly("SPARK-18342: commit fails when rename fails") {
@@ -400,7 +397,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     val provider = newStoreProvider(
       opId = Random.nextInt, partition = 0, dir = dir, hadoopConf = conf)
     val store = provider.getStore(0)
-    put(store, "a", 0)
+    put(store, "a", 0, 0)
     val e = intercept[IllegalStateException](store.commit())
     assert(e.getCause.getMessage.contains("Failed to rename"))
   }
@@ -434,11 +431,12 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     // Getting the store should not create temp file
     val store0 = shouldNotCreateTempFile {
       StateStore.get(
-        storeId, keySchema, valueSchema, indexOrdinal = None, version = 0, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, numColsPrefixKey = 0,
+        version = 0, storeConf, hadoopConf)
     }
 
     // Put should create a temp file
-    put(store0, "a", 1)
+    put(store0, "a", 0, 1)
     assert(numTempFiles === 1)
     assert(numDeltaFiles === 0)
 
@@ -450,9 +448,10 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     // Remove should create a temp file
     val store1 = shouldNotCreateTempFile {
       StateStore.get(
-        storeId, keySchema, valueSchema, indexOrdinal = None, version = 1, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, numColsPrefixKey = 0,
+        version = 1, storeConf, hadoopConf)
     }
-    remove(store1, _ == "a")
+    remove(store1, _._1 == "a")
     assert(numTempFiles === 1)
     assert(numDeltaFiles === 1)
 
@@ -464,7 +463,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     // Commit without any updates should create a delta file
     val store2 = shouldNotCreateTempFile {
       StateStore.get(
-        storeId, keySchema, valueSchema, indexOrdinal = None, version = 2, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, numColsPrefixKey = 0,
+        version = 2, storeConf, hadoopConf)
     }
     store2.commit()
     assert(numTempFiles === 0)
@@ -535,7 +535,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     CreateAtomicTestManager.shouldFailInCreateAtomic = false
     for (version <- 1 to 10) {
       val store = provider.getStore(version - 1)
-      put(store, version.toString, version) // update "1" -> 1, "2" -> 2, ...
+      put(store, version.toString, 0, version) // update "1" -> 1, "2" -> 2, ...
       store.commit()
     }
     val version10Data = (1L to 10).map(_.toString).map(x => x -> x).toSet
@@ -544,7 +544,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     val store = provider.getStore(10)
     // Fail commit for next version and verify that reloading resets the files
     CreateAtomicTestManager.shouldFailInCreateAtomic = true
-    put(store, "11", 11)
+    put(store, "11", 0, 11)
     val e = intercept[IllegalStateException] { quietly { store.commit() } }
     assert(e.getCause.isInstanceOf[IOException])
     CreateAtomicTestManager.shouldFailInCreateAtomic = false
@@ -552,7 +552,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     // Abort commit for next version and verify that reloading resets the files
     CreateAtomicTestManager.cancelCalledInCreateAtomic = false
     val store2 = provider.getStore(10)
-    put(store2, "11", 11)
+    put(store2, "11", 0, 11)
     store2.abort()
     assert(CreateAtomicTestManager.cancelCalledInCreateAtomic)
   }
@@ -592,13 +592,13 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     assert(initialLoadedMapSize >= 0)
     assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0)
 
-    put(store, "a", 1)
+    put(store, "a", 0, 1)
     assert(store.metrics.numKeys === 1)
 
-    put(store, "b", 2)
-    put(store, "aa", 3)
+    put(store, "b", 0, 2)
+    put(store, "aa", 0, 3)
     assert(store.metrics.numKeys === 3)
-    remove(store, _.startsWith("a"))
+    remove(store, _._1.startsWith("a"))
     assert(store.metrics.numKeys === 1)
     assert(store.commit() === 1)
 
@@ -612,7 +612,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     assert(!storeV2.hasCommitted)
     assert(storeV2.metrics.numKeys === 1)
 
-    put(storeV2, "cc", 4)
+    put(storeV2, "cc", 0, 4)
     assert(storeV2.metrics.numKeys === 2)
     assert(storeV2.commit() === 2)
 
@@ -628,7 +628,6 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     val reloadedStore = reloadedProvider.getStore(1)
     assert(reloadedStore.metrics.numKeys === 1)
 
-    assert(getLoadedMapSizeMetric(reloadedStore.metrics) === loadedMapSizeForVersion1)
     assertCacheHitAndMiss(reloadedStore.metrics, expectedCacheHitCount = 0,
       expectedCacheMissCount = 1)
 
@@ -657,18 +656,19 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
       numOfVersToRetainInMemory = numOfVersToRetainInMemory)
   }
 
-  override def getLatestData(storeProvider: HDFSBackedStateStoreProvider): Set[(String, Int)] = {
+  override def getLatestData(
+      storeProvider: HDFSBackedStateStoreProvider): Set[((String, Int), Int)] = {
     getData(storeProvider, -1)
   }
 
   override def getData(
     provider: HDFSBackedStateStoreProvider,
-    version: Int): Set[(String, Int)] = {
+    version: Int): Set[((String, Int), Int)] = {
     val reloadedProvider = newStoreProvider(provider.stateStoreId)
     if (version < 0) {
-      reloadedProvider.latestIterator().map(rowsToStringInt).toSet
+      reloadedProvider.latestIterator().map(rowPairToDataPair).toSet
     } else {
-      reloadedProvider.getStore(version).iterator().map(rowsToStringInt).toSet
+      reloadedProvider.getStore(version).iterator().map(rowPairToDataPair).toSet
     }
   }
 
@@ -686,6 +686,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   def newStoreProvider(
       opId: Long,
       partition: Int,
+      numColsPrefixKey: Int = 0,
       dir: String = newDir(),
       minDeltasForSnapshot: Int = SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.defaultValue.get,
       numOfVersToRetainInMemory: Int = SQLConf.MAX_BATCHES_TO_RETAIN_IN_MEMORY.defaultValue.get,
@@ -696,14 +697,18 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
       StateStoreId(dir, opId, partition),
       keySchema,
       valueSchema,
-      indexOrdinal = None,
+      numColsPrefixKey = numColsPrefixKey,
       new StateStoreConf(sqlConf),
       hadoopConf)
     provider
   }
 
+  override def newStoreProvider(numPrefixCols: Int): HDFSBackedStateStoreProvider = {
+    newStoreProvider(opId = Random.nextInt(), partition = 0, numColsPrefixKey = numPrefixCols)
+  }
+
   def checkLoadedVersions(
-      loadedMaps: util.SortedMap[Long, ProviderMapType],
+      loadedMaps: util.SortedMap[Long, HDFSBackedStateStoreMap],
       count: Int,
       earliestKey: Long,
       latestKey: Long): Unit = {
@@ -713,11 +718,11 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
   }
 
   def checkVersion(
-      loadedMaps: util.SortedMap[Long, ProviderMapType],
+      loadedMaps: util.SortedMap[Long, HDFSBackedStateStoreMap],
       version: Long,
-      expectedData: Map[String, Int]): Unit = {
-    val originValueMap = loadedMaps.get(version).asScala.map { entry =>
-      rowToString(entry._1) -> rowToInt(entry._2)
+      expectedData: Map[(String, Int), Int]): Unit = {
+    val originValueMap = loadedMaps.get(version).iterator().map { entry =>
+      keyRowToData(entry.key) -> valueRowToData(entry.value)
     }.toMap
 
     assert(originValueMap === expectedData)
@@ -741,10 +746,9 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
   import StateStoreTestsHelper._
 
   type MapType = mutable.HashMap[UnsafeRow, UnsafeRow]
-  type ProviderMapType = java.util.concurrent.ConcurrentHashMap[UnsafeRow, UnsafeRow]
 
-  protected val keySchema: StructType
-  protected val valueSchema: StructType
+  protected val keySchema: StructType = StateStoreTestsHelper.keySchema
+  protected val valueSchema: StructType = StateStoreTestsHelper.valueSchema
 
   testWithAllCodec("get, put, remove, commit, and all data iterator") {
     val provider = newStoreProvider()
@@ -754,26 +758,26 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
 
     val store = provider.getStore(0)
     assert(!store.hasCommitted)
-    assert(get(store, "a") === None)
+    assert(get(store, "a", 0) === None)
     assert(store.iterator().isEmpty)
     assert(store.metrics.numKeys === 0)
 
     // Verify state after updating
-    put(store, "a", 1)
-    assert(get(store, "a") === Some(1))
+    put(store, "a", 0, 1)
+    assert(get(store, "a", 0) === Some(1))
 
     assert(store.iterator().nonEmpty)
     assert(getLatestData(provider).isEmpty)
 
     // Make updates, commit and then verify state
-    put(store, "b", 2)
-    put(store, "aa", 3)
-    remove(store, _.startsWith("a"))
+    put(store, "b", 0, 2)
+    put(store, "aa", 0, 3)
+    remove(store, _._1.startsWith("a"))
     assert(store.commit() === 1)
 
     assert(store.hasCommitted)
-    assert(rowsToSet(store.iterator()) === Set("b" -> 2))
-    assert(getLatestData(provider) === Set("b" -> 2))
+    assert(rowPairsToDataSet(store.iterator()) === Set(("b", 0) -> 2))
+    assert(getLatestData(provider) === Set(("b", 0) -> 2))
 
     // Trying to get newer versions should fail
     intercept[Exception] {
@@ -786,11 +790,40 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     // New updates to the reloaded store with new version, and does not change old version
     val reloadedProvider = newStoreProvider(store.id)
     val reloadedStore = reloadedProvider.getStore(1)
-    put(reloadedStore, "c", 4)
+    put(reloadedStore, "c", 0, 4)
     assert(reloadedStore.commit() === 2)
-    assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4))
-    assert(getLatestData(provider) === Set("b" -> 2, "c" -> 4))
-    assert(getData(provider, version = 1) === Set("b" -> 2))
+    assert(rowPairsToDataSet(reloadedStore.iterator()) === Set(("b", 0) -> 2, ("c", 0) -> 4))
+    assert(getLatestData(provider) === Set(("b", 0) -> 2, ("c", 0) -> 4))
+    assert(getData(provider, version = 1) === Set(("b", 0) -> 2))
+  }
+
+  testWithAllCodec("prefix scan") {
+    val provider = newStoreProvider(numPrefixCols = 1)
+
+    // Verify state before starting a new set of updates
+    assert(getLatestData(provider).isEmpty)
+
+    val store = provider.getStore(0)
+
+    val key1 = Seq("a", "b", "c")
+    val key2 = Seq(1, 2, 3)
+    val keys = for (k1 <- key1; k2 <- key2) yield (k1, k2)
+
+    val randomizedKeys = scala.util.Random.shuffle(keys.toList)
+
+    randomizedKeys.foreach { case (key1, key2) =>
+      put(store, key1, key2, key2)
+    }
+
+    key1.foreach { k1 =>
+      val keyValueSet = store.prefixScan(dataToPrefixKeyRow(k1)).map { pair =>
+        rowPairToDataPair(pair.withRows(pair.key.copy(), pair.value.copy()))
+      }.toSet
+
+      assert(keyValueSet === key2.map(k2 => ((k1, k2), k2)).toSet)
+    }
+
+    assert(store.prefixScan(dataToPrefixKeyRow("non-exist")).isEmpty)
   }
 
   testWithAllCodec("numKeys metrics") {
@@ -800,21 +833,23 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     assert(getLatestData(provider).isEmpty)
 
     val store = provider.getStore(0)
-    put(store, "a", 1)
-    put(store, "b", 2)
-    put(store, "c", 3)
-    put(store, "d", 4)
-    put(store, "e", 5)
+    put(store, "a", 0, 1)
+    put(store, "b", 0, 2)
+    put(store, "c", 0, 3)
+    put(store, "d", 0, 4)
+    put(store, "e", 0, 5)
     assert(store.commit() === 1)
     assert(store.metrics.numKeys === 5)
-    assert(rowsToSet(store.iterator()) === Set("a" -> 1, "b" -> 2, "c" -> 3, "d" -> 4, "e" -> 5))
+    assert(rowPairsToDataSet(store.iterator()) ===
+      Set(("a", 0) -> 1, ("b", 0) -> 2, ("c", 0) -> 3, ("d", 0) -> 4, ("e", 0) -> 5))
 
     val reloadedProvider = newStoreProvider(store.id)
     val reloadedStore = reloadedProvider.getStore(1)
-    remove(reloadedStore, _ == "b")
+    remove(reloadedStore, _._1 == "b")
     assert(reloadedStore.commit() === 2)
     assert(reloadedStore.metrics.numKeys === 4)
-    assert(rowsToSet(reloadedStore.iterator()) === Set("a" -> 1, "c" -> 3, "d" -> 4, "e" -> 5))
+    assert(rowPairsToDataSet(reloadedStore.iterator()) ===
+      Set(("a", 0) -> 1, ("c", 0) -> 3, ("d", 0) -> 4, ("e", 0) -> 5))
   }
 
   testWithAllCodec("removing while iterating") {
@@ -823,32 +858,32 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     // Verify state before starting a new set of updates
     assert(getLatestData(provider).isEmpty)
     val store = provider.getStore(0)
-    put(store, "a", 1)
-    put(store, "b", 2)
+    put(store, "a", 0, 1)
+    put(store, "b", 0, 2)
 
     // Updates should work while iterating of filtered entries
-    val filtered = store.iterator.filter { tuple => rowToString(tuple.key) == "a" }
+    val filtered = store.iterator.filter { tuple => keyRowToData(tuple.key) == ("a", 0) }
     filtered.foreach { tuple =>
-      store.put(tuple.key, intToRow(rowToInt(tuple.value) + 1))
+      store.put(tuple.key, dataToValueRow(valueRowToData(tuple.value) + 1))
     }
-    assert(get(store, "a") === Some(2))
+    assert(get(store, "a", 0) === Some(2))
 
     // Removes should work while iterating of filtered entries
-    val filtered2 = store.iterator.filter { tuple => rowToString(tuple.key) == "b" }
+    val filtered2 = store.iterator.filter { tuple => keyRowToData(tuple.key) == ("b", 0) }
     filtered2.foreach { tuple => store.remove(tuple.key) }
-    assert(get(store, "b") === None)
+    assert(get(store, "b", 0) === None)
   }
 
   testWithAllCodec("abort") {
     val provider = newStoreProvider()
     val store = provider.getStore(0)
-    put(store, "a", 1)
+    put(store, "a", 0, 1)
     store.commit()
-    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+    assert(rowPairsToDataSet(store.iterator()) === Set(("a", 0) -> 1))
 
     // cancelUpdates should not change the data in the files
     val store1 = provider.getStore(1)
-    put(store1, "b", 1)
+    put(store1, "b", 0, 1)
     store1.abort()
   }
 
@@ -865,22 +900,22 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     checkInvalidVersion(1)
 
     val store = provider.getStore(0)
-    put(store, "a", 1)
+    put(store, "a", 0, 1)
     assert(store.commit() === 1)
-    assert(rowsToSet(store.iterator()) === Set("a" -> 1))
+    assert(rowPairsToDataSet(store.iterator()) === Set(("a", 0) -> 1))
 
     val store1_ = provider.getStore(1)
-    assert(rowsToSet(store1_.iterator()) === Set("a" -> 1))
+    assert(rowPairsToDataSet(store1_.iterator()) === Set(("a", 0) -> 1))
 
     checkInvalidVersion(-1)
     checkInvalidVersion(2)
 
     // Update store version with some data
     val store1 = provider.getStore(1)
-    assert(rowsToSet(store1.iterator()) === Set("a" -> 1))
-    put(store1, "b", 1)
+    assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1))
+    put(store1, "b", 0, 1)
     assert(store1.commit() === 2)
-    assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1))
+    assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1, ("b", 0) -> 1))
 
     checkInvalidVersion(-1)
     checkInvalidVersion(3)
@@ -897,24 +932,25 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     val provider0 = newStoreProvider(storeId)
     // prime state
     val store = provider0.getStore(0)
-    val key = "a"
-    put(store, key, 1)
+    val key1 = "a"
+    val key2 = 0
+    put(store, key1, key2, 1)
     store.commit()
-    assert(rowsToSet(store.iterator()) === Set(key -> 1))
+    assert(rowPairsToDataSet(store.iterator()) === Set((key1, key2) -> 1))
 
     // two state stores
     val provider1 = newStoreProvider(storeId)
     val restoreStore = provider1.getReadStore(1)
     val saveStore = provider1.getStore(1)
 
-    put(saveStore, key, get(restoreStore, key).get + 1)
+    put(saveStore, key1, key2, get(restoreStore, key1, key2).get + 1)
     saveStore.commit()
     restoreStore.abort()
 
     // check that state is correct for next batch
     val provider2 = newStoreProvider(storeId)
     val finalStore = provider2.getStore(2)
-    assert(rowsToSet(finalStore.iterator()) === Set(key -> 2))
+    assert(rowPairsToDataSet(finalStore.iterator()) === Set((key1, key2) -> 2))
   }
 
   test("StateStore.get") {
@@ -927,45 +963,45 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
       // Verify that trying to get incorrect versions throws errors
       intercept[IllegalArgumentException] {
         StateStore.get(
-          storeId, keySchema, valueSchema, None, -1, storeConf, hadoopConf)
+          storeId, keySchema, valueSchema, 0, -1, storeConf, hadoopConf)
       }
       assert(!StateStore.isLoaded(storeId)) // version -1 should not attempt to load the store
 
       intercept[IllegalStateException] {
         StateStore.get(
-          storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
+          storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf)
       }
 
       // Increase version of the store and try to get again
       val store0 = StateStore.get(
-        storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, 0, 0, storeConf, hadoopConf)
       assert(store0.version === 0)
-      put(store0, "a", 1)
+      put(store0, "a", 0, 1)
       store0.commit()
 
       val store1 = StateStore.get(
-        storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf)
       assert(StateStore.isLoaded(storeId))
       assert(store1.version === 1)
-      assert(rowsToSet(store1.iterator()) === Set("a" -> 1))
+      assert(rowPairsToDataSet(store1.iterator()) === Set(("a", 0) -> 1))
 
       // Verify that you can also load an older version
       val store0reloaded = StateStore.get(
-        storeId, keySchema, valueSchema, None, 0, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, 0, 0, storeConf, hadoopConf)
       assert(store0reloaded.version === 0)
-      assert(rowsToSet(store0reloaded.iterator()) === Set.empty)
+      assert(rowPairsToDataSet(store0reloaded.iterator()) === Set.empty)
 
       // Verify that you can remove the store and still reload and use it
       StateStore.unload(storeId)
       assert(!StateStore.isLoaded(storeId))
 
       val store1reloaded = StateStore.get(
-        storeId, keySchema, valueSchema, None, 1, storeConf, hadoopConf)
+        storeId, keySchema, valueSchema, 0, 1, storeConf, hadoopConf)
       assert(StateStore.isLoaded(storeId))
       assert(store1reloaded.version === 1)
-      put(store1reloaded, "a", 2)
+      put(store1reloaded, "a", 0, 2)
       assert(store1reloaded.commit() === 2)
-      assert(rowsToSet(store1reloaded.iterator()) === Set("a" -> 2))
+      assert(rowPairsToDataSet(store1reloaded.iterator()) === Set(("a", 0) -> 2))
     }
   }
 
@@ -973,7 +1009,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
     val provider = newStoreProvider()
     val store = provider.getStore(0)
     val noDataMemoryUsed = store.metrics.memoryUsedBytes
-    put(store, "a", 1)
+    put(store, "a", 0, 1)
     store.commit()
     assert(store.metrics.memoryUsedBytes > noDataMemoryUsed)
   }
@@ -1008,7 +1044,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
 
     val store = provider.getStore(0)
     val err = intercept[IllegalArgumentException] {
-      store.put(stringToRow("key"), null)
+      store.put(dataToKeyRow("key", 0), null)
     }
     assert(err.getMessage.contains("Cannot put a null value"))
   }
@@ -1036,14 +1072,17 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
   /** Return a new provider with minimum delta and version to retain in memory */
   def newStoreProvider(minDeltasForSnapshot: Int, numOfVersToRetainInMemory: Int): ProviderClass
 
+  /** Return a new provider with the given number of prefix key columns */
+  def newStoreProvider(numPrefixCols: Int): ProviderClass
+
   /** Get the latest data referred to by the given provider but not using this provider */
-  def getLatestData(storeProvider: ProviderClass): Set[(String, Int)]
+  def getLatestData(storeProvider: ProviderClass): Set[((String, Int), Int)]
 
   /**
    * Get a specific version of data referred to by the given provider but not using
    * this provider
    */
-  def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)]
+  def getData(storeProvider: ProviderClass, version: Int): Set[((String, Int), Int)]
 
   protected def testQuietly(name: String)(f: => Unit): Unit = {
     test(name) {
@@ -1084,7 +1123,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
 
   def incrementVersion(provider: StateStoreProvider, currentVersion: Int): Int = {
     val store = provider.getStore(currentVersion)
-    put(store, "a", currentVersion + 1)
+    put(store, "a", 0, currentVersion + 1)
     store.commit()
     currentVersion + 1
   }
@@ -1104,45 +1143,54 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider]
 
 object StateStoreTestsHelper {
 
-  val strProj = UnsafeProjection.create(Array[DataType](StringType))
-  val intProj = UnsafeProjection.create(Array[DataType](IntegerType))
+  val keySchema = StructType(
+    Seq(StructField("key1", StringType, true), StructField("key2", IntegerType, true)))
+  val valueSchema = StructType(Seq(StructField("value", IntegerType, true)))
+
+  val keyProj = UnsafeProjection.create(Array[DataType](StringType, IntegerType))
+  val prefixKeyProj = UnsafeProjection.create(Array[DataType](StringType))
+  val valueProj = UnsafeProjection.create(Array[DataType](IntegerType))
+
+  def dataToPrefixKeyRow(s: String): UnsafeRow = {
+    prefixKeyProj.apply(new GenericInternalRow(Array[Any](UTF8String.fromString(s)))).copy()
+  }
 
-  def stringToRow(s: String): UnsafeRow = {
-    strProj.apply(new GenericInternalRow(Array[Any](UTF8String.fromString(s)))).copy()
+  def dataToKeyRow(s: String, i: Int): UnsafeRow = {
+    keyProj.apply(new GenericInternalRow(Array[Any](UTF8String.fromString(s), i))).copy()
   }
 
-  def intToRow(i: Int): UnsafeRow = {
-    intProj.apply(new GenericInternalRow(Array[Any](i))).copy()
+  def dataToValueRow(i: Int): UnsafeRow = {
+    valueProj.apply(new GenericInternalRow(Array[Any](i))).copy()
   }
 
-  def rowToString(row: UnsafeRow): String = {
-    row.getUTF8String(0).toString
+  def keyRowToData(row: UnsafeRow): (String, Int) = {
+    (row.getUTF8String(0).toString, row.getInt(1))
   }
 
-  def rowToInt(row: UnsafeRow): Int = {
+  def valueRowToData(row: UnsafeRow): Int = {
     row.getInt(0)
   }
 
-  def rowsToStringInt(row: UnsafeRowPair): (String, Int) = {
-    (rowToString(row.key), rowToInt(row.value))
+  def rowPairToDataPair(row: UnsafeRowPair): ((String, Int), Int) = {
+    (keyRowToData(row.key), valueRowToData(row.value))
   }
 
-  def rowsToSet(iterator: Iterator[UnsafeRowPair]): Set[(String, Int)] = {
-    iterator.map(rowsToStringInt).toSet
+  def rowPairsToDataSet(iterator: Iterator[UnsafeRowPair]): Set[((String, Int), Int)] = {
+    iterator.map(rowPairToDataPair).toSet
   }
 
-  def remove(store: StateStore, condition: String => Boolean): Unit = {
-    store.getRange(None, None).foreach { rowPair =>
-      if (condition(rowToString(rowPair.key))) store.remove(rowPair.key)
+  def remove(store: StateStore, condition: ((String, Int)) => Boolean): Unit = {
+    store.iterator().foreach { rowPair =>
+      if (condition(keyRowToData(rowPair.key))) store.remove(rowPair.key)
     }
   }
 
-  def put(store: StateStore, key: String, value: Int): Unit = {
-    store.put(stringToRow(key), intToRow(value))
+  def put(store: StateStore, key1: String, key2: Int, value: Int): Unit = {
+    store.put(dataToKeyRow(key1, key2), dataToValueRow(value))
   }
 
-  def get(store: ReadStateStore, key: String): Option[Int] = {
-    Option(store.get(stringToRow(key))).map(rowToInt)
+  def get(store: ReadStateStore, key1: String, key2: Int): Option[Int] = {
+    Option(store.get(dataToKeyRow(key1, key2))).map(valueRowToData)
   }
 
   def newDir(): String = Utils.createTempDir().toString
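
A minimal usage sketch (not part of this commit) of the reworked helpers above, written as it might appear inside StateStoreSuiteBase. It assumes StateStoreTestsHelper._ is imported as in the suite, that the concrete suite implements newStoreProvider(numPrefixCols: Int), and that the store exposes a prefixScan(prefixKey: UnsafeRow) iterator for the feature this change introduces; the exact method name is an assumption here.

```
// Sketch only -- illustrative test body, not taken from the commit.
test("prefix scan returns only the keys sharing the prefix") {
  // numPrefixCols = 1: the leftmost key column ("key1") acts as the prefix key.
  val provider = newStoreProvider(1)
  val store = provider.getStore(0)

  // Two keys share the "a" prefix, one does not.
  put(store, "a", 1, 10)
  put(store, "a", 2, 20)
  put(store, "b", 1, 30)

  // prefixScan(...) is assumed to return Iterator[UnsafeRowPair] for matching keys.
  val matched = store.prefixScan(dataToPrefixKeyRow("a")).map(rowPairToDataPair).toSet
  assert(matched === Set(("a", 1) -> 10, ("a", 2) -> 20))

  store.commit()
}
```
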
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index fb6922a..e60f706 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -1418,7 +1418,7 @@ class TestStateStoreProvider extends StateStoreProvider {
       stateStoreId: StateStoreId,
       keySchema: StructType,
       valueSchema: StructType,
-      indexOrdinal: Option[Int],
+      numColsPrefixKey: Int,
       storeConfs: StateStoreConf,
       hadoopConf: Configuration): Unit = {
     throw new Exception("Successfully instantiated")
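
For third-party providers adjusting to the signature change shown above (indexOrdinal: Option[Int] replaced by numColsPrefixKey: Int), here is a hedged sketch of a minimal provider skeleton; only the init parameter list mirrors the change, everything else is illustrative and not taken from this commit.

```
// Sketch only: a hypothetical provider adapting to the new init signature.
import org.apache.hadoop.conf.Configuration
import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider}
import org.apache.spark.sql.types.StructType

class SketchStateStoreProvider extends StateStoreProvider {
  private var numColsPrefixKey: Int = 0

  override def init(
      stateStoreId: StateStoreId,
      keySchema: StructType,
      valueSchema: StructType,
      numColsPrefixKey: Int,
      storeConfs: StateStoreConf,
      hadoopConf: Configuration): Unit = {
    // 0 means the operator does not use prefix match scan; a positive value asks the
    // provider to support scans over the leftmost numColsPrefixKey key columns.
    this.numColsPrefixKey = numColsPrefixKey
    // ... initialize the underlying storage here ...
  }

  // Remaining members are left unimplemented in this sketch.
  override def stateStoreId: StateStoreId = throw new UnsupportedOperationException("sketch")
  override def close(): Unit = {}
  override def getStore(version: Long): StateStore = throw new UnsupportedOperationException("sketch")
}
```
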

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org