Posted to commits@spark.apache.org by ka...@apache.org on 2024/02/21 05:43:40 UTC

(spark) branch master updated: [SPARK-46928][SS] Add support for ListState in Arbitrary State API v2

This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0b907ed11e6e [SPARK-46928][SS] Add support for ListState in Arbitrary State API v2
0b907ed11e6e is described below

commit 0b907ed11e6ec6bc4c7d07926ed352806636d58a
Author: Bhuwan Sahni <bh...@databricks.com>
AuthorDate: Wed Feb 21 14:43:27 2024 +0900

    [SPARK-46928][SS] Add support for ListState in Arbitrary State API v2
    
    ### What changes were proposed in this pull request?
    
    This PR adds the ListState implementation in the State API v2. Because a list holds multiple values for a single key, we use the RocksDB merge operator to persist those values.
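    
    For context, here is a minimal standalone sketch (assumed RocksJava setup and paths; not Spark code) of the merge-operator behaviour relied on here: with StringAppendOperator configured, merge() appends the new bytes to the existing value instead of overwriting it, and behaves like put() when no prior value exists.
    
    ```scala
    import org.rocksdb.{Options, RocksDB, StringAppendOperator}
    
    object RocksDbMergeSketch {
      def main(args: Array[String]): Unit = {
        RocksDB.loadLibrary()
        // StringAppendOperator concatenates merged values with a delimiter (',' by default).
        val options = new Options()
          .setCreateIfMissing(true)
          .setMergeOperator(new StringAppendOperator())
        val db = RocksDB.open(options, "/tmp/merge-sketch-db") // illustrative path
        try {
          val key = "k".getBytes("UTF-8")
          db.merge(key, "v1".getBytes("UTF-8")) // acts like put: no prior value exists
          db.merge(key, "v2".getBytes("UTF-8")) // appends: stored value becomes "v1,v2"
          println(new String(db.get(key), "UTF-8"))
        } finally {
          db.close()
          options.close()
        }
      }
    }
    ```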
    
    Changes include
    
    1. A new encoder/decoder that encodes multiple values inside a single byte[] array (stored in RocksDB). The encoding scheme is compatible with the RocksDB StringAppendOperator merge operator; a simplified sketch of the layout follows this list.
    2. Support for merge operations in changelog checkpointing v2.
    3. An extension of StateStore to support the merge operation and to read multiple values for a single key (via an Iterator). Note that these changes are currently only supported for RocksDB.
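    
    To make the byte layout in item 1 concrete, here is a simplified, self-contained sketch (illustrative names and big-endian length prefixes; the actual implementation is the MultiValuedStateEncoder added in RocksDBStateEncoder.scala below). Each value is written as a 4-byte length prefix followed by its payload; when RocksDB merges encoded values it joins them with a 1-byte delimiter, which the decoder skips.
    
    ```scala
    import java.nio.ByteBuffer
    
    object MultiValueEncodingSketch {
      // Encode a single value as [4-byte length][payload].
      def encodeValue(payload: Array[Byte]): Array[Byte] = {
        val buf = ByteBuffer.allocate(java.lang.Integer.BYTES + payload.length)
        buf.putInt(payload.length)
        buf.put(payload)
        buf.array()
      }
    
      // Split a merged byte array back into the individual encoded values.
      def decodeValues(merged: Array[Byte]): Iterator[Array[Byte]] = new Iterator[Array[Byte]] {
        private var pos = 0
        override def hasNext: Boolean = pos < merged.length
        override def next(): Array[Byte] = {
          val numBytes = ByteBuffer.wrap(merged, pos, java.lang.Integer.BYTES).getInt
          pos += java.lang.Integer.BYTES
          val value = java.util.Arrays.copyOfRange(merged, pos, pos + numBytes)
          pos += numBytes + 1 // skip the payload and the 1-byte merge delimiter
          value
        }
      }
    }
    ```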
    
    ### Why are the changes needed?
    
    These changes are needed to support list values in the State Store. They are part of the work to add a new stateful streaming operator for arbitrary state management, which provides the new features listed in the SPIP JIRA: https://issues.apache.org/jira/browse/SPARK-45939
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes.
    This PR introduces a new state type (ListState) that users can use in their Spark streaming queries.
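    
    As a rough illustration of the new surface (ListState and StatefulProcessorHandle are still marked private[sql] in this change, so the hypothetical helper below only sketches how a stateful processor might use the handle obtained from transformWithState):
    
    ```scala
    import org.apache.spark.sql.streaming.{ListState, StatefulProcessorHandle}
    
    // Hypothetical helper; the wrapping stateful processor and query wiring are omitted.
    class EventBufferSketch(handle: StatefulProcessorHandle) {
      // Create or re-open a list state variable named "events" holding String values.
      private val events: ListState[String] = handle.getListState[String]("events")
    
      def addEvent(event: String): Unit = {
        if (!events.exists()) {
          events.put(Array(event))  // initialize the list for the current grouping key
        } else {
          events.appendValue(event) // append a single entry to the existing list
        }
      }
    
      def allEvents(): Seq[String] = events.get().toSeq // empty iterator if no state exists
    
      def reset(): Unit = events.clear() // remove the list for the current grouping key
    }
    ```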
    
    ### How was this patch tested?
    
    1. Added a new test suite for ListState to ensure the state produces correct results.
    2. Added additional test cases for input validation.
    3. Added tests for merge operator with RocksDB.
    4. Added tests for changelog checkpointing merge operator.
    5. Added tests for reading merged values in RocksDBStateStore.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #44961 from sahnib/state-api-v2-list-state.
    
    Authored-by: Bhuwan Sahni <bh...@databricks.com>
    Signed-off-by: Jungtaek Lim <ka...@gmail.com>
---
 .../src/main/resources/error/error-classes.json    |  18 ++
 ...itions-illegal-state-store-value-error-class.md |  41 +++
 docs/sql-error-conditions.md                       |   8 +
 .../{ValueState.scala => ListState.scala}          |  30 +-
 .../sql/streaming/StatefulProcessorHandle.scala    |  10 +
 .../apache/spark/sql/streaming/ValueState.scala    |   2 +-
 .../v2/state/StatePartitionReader.scala            |   2 +-
 .../sql/execution/streaming/ListStateImpl.scala    | 121 ++++++++
 .../streaming/StateTypesEncoderUtils.scala         |  88 ++++++
 .../streaming/StatefulProcessorHandleImpl.scala    |   8 +-
 .../streaming/TransformWithStateExec.scala         |   6 +-
 .../sql/execution/streaming/ValueStateImpl.scala   |  61 +---
 .../state/HDFSBackedStateStoreProvider.scala       |  27 +-
 .../sql/execution/streaming/state/RocksDB.scala    |  37 +++
 .../streaming/state/RocksDBStateEncoder.scala      |  96 +++++-
 .../state/RocksDBStateStoreProvider.scala          |  53 +++-
 .../sql/execution/streaming/state/StateStore.scala |  53 +++-
 .../streaming/state/StateStoreChangelog.scala      |  48 ++-
 .../streaming/state/StateStoreErrors.scala         |  12 +
 .../execution/streaming/state/StateStoreRDD.scala  |   5 +-
 .../state/SymmetricHashJoinStateManager.scala      |   3 +-
 .../sql/execution/streaming/state/package.scala    |   6 +-
 .../streaming/state/MemoryStateStore.scala         |  11 +-
 .../streaming/state/RocksDBStateStoreSuite.scala   |  56 +++-
 .../execution/streaming/state/RocksDBSuite.scala   |  84 ++++++
 .../streaming/state/StateStoreSuite.scala          |   7 +-
 .../streaming/state/ValueStateSuite.scala          |  12 +-
 .../apache/spark/sql/streaming/StreamSuite.scala   |   3 +-
 .../streaming/TransformWithListStateSuite.scala    | 328 +++++++++++++++++++++
 .../sql/streaming/TransformWithStateSuite.scala    |   2 +-
 30 files changed, 1120 insertions(+), 118 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index c1b1171b5dc8..b30b1d60bb4a 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -1380,6 +1380,24 @@
     ],
     "sqlState" : "42601"
   },
+  "ILLEGAL_STATE_STORE_VALUE" : {
+    "message" : [
+      "Illegal value provided to the State Store"
+    ],
+    "subClass" : {
+      "EMPTY_LIST_VALUE" : {
+        "message" : [
+          "Cannot write empty list values to State Store for StateName <stateName>."
+        ]
+      },
+      "NULL_VALUE" : {
+        "message" : [
+          "Cannot write null values to State Store for StateName <stateName>."
+        ]
+      }
+    },
+    "sqlState" : "42601"
+  },
   "INCOMPARABLE_PIVOT_COLUMN" : {
     "message" : [
       "Invalid pivot column <columnName>. Pivot columns must be comparable."
diff --git a/docs/sql-error-conditions-illegal-state-store-value-error-class.md b/docs/sql-error-conditions-illegal-state-store-value-error-class.md
new file mode 100644
index 000000000000..e6457e58b7b4
--- /dev/null
+++ b/docs/sql-error-conditions-illegal-state-store-value-error-class.md
@@ -0,0 +1,41 @@
+---
+layout: global
+title: ILLEGAL_STATE_STORE_VALUE error class
+displayTitle: ILLEGAL_STATE_STORE_VALUE error class
+license: |
+  Licensed to the Apache Software Foundation (ASF) under one or more
+  contributor license agreements.  See the NOTICE file distributed with
+  this work for additional information regarding copyright ownership.
+  The ASF licenses this file to You under the Apache License, Version 2.0
+  (the "License"); you may not use this file except in compliance with
+  the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+---
+
+<!--
+  DO NOT EDIT THIS FILE.
+  It was generated automatically by `org.apache.spark.SparkThrowableSuite`.
+-->
+
+[SQLSTATE: 42601](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Illegal value provided to the State Store
+
+This error class has the following derived error classes:
+
+## EMPTY_LIST_VALUE
+
+Cannot write empty list values to State Store for StateName `<stateName>`.
+
+## NULL_VALUE
+
+Cannot write null values to State Store for StateName `<stateName>`.
+
+
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 1e5e8aad6196..ebf7436f96e7 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -864,6 +864,14 @@ Sketches have different `lgConfigK` values: `<left>` and `<right>`. Set the `all
 
 `<identifier>` is not a valid identifier as it has more than 2 name parts.
 
+### [ILLEGAL_STATE_STORE_VALUE](sql-error-conditions-illegal-state-store-value-error-class.html)
+
+[SQLSTATE: 42601](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Illegal value provided to the State Store
+
+For more details see [ILLEGAL_STATE_STORE_VALUE](sql-error-conditions-illegal-state-store-value-error-class.html)
+
 ### INCOMPARABLE_PIVOT_COLUMN
 
 [SQLSTATE: 42818](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala
similarity index 68%
copy from sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
copy to sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala
index 25f238b3a25e..0e2d6cc3778c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ListState.scala
@@ -14,37 +14,33 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.spark.sql.streaming
 
-import java.io.Serializable
-
 import org.apache.spark.annotation.{Evolving, Experimental}
 
 @Experimental
 @Evolving
 /**
  * Interface used for arbitrary stateful operations with the v2 API to capture
- * single value state.
+ * list value state.
  */
-private[sql] trait ValueState[S] extends Serializable {
+private[sql] trait ListState[S] extends Serializable {
 
   /** Whether state exists or not. */
   def exists(): Boolean
 
-  /**
-   * Get the state value if it exists
-   * @throws java.util.NoSuchElementException if the state does not exist
-   */
-  @throws[NoSuchElementException]
-  def get(): S
+  /** Get the state value. An empty iterator is returned if no value exists. */
+  def get(): Iterator[S]
+
+  /** Update the value of the list. */
+  def put(newState: Array[S]): Unit
 
-  /** Get the state if it exists as an option and None otherwise */
-  def getOption(): Option[S]
+  /** Append an entry to the list */
+  def appendValue(newState: S): Unit
 
-  /** Update the value of the state. */
-  def update(newState: S): Unit
+  /** Append an entire list to the existing value */
+  def appendList(newState: Array[S]): Unit
 
-  /** Remove this state. */
-  def remove(): Unit
+  /** Removes this state for the given grouping key. */
+  def clear(): Unit
 }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
index 738928b5cc36..5d3390f80f6d 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessorHandle.scala
@@ -38,6 +38,16 @@ private[sql] trait StatefulProcessorHandle extends Serializable {
    */
   def getValueState[T](stateName: String): ValueState[T]
 
+  /**
+   * Creates new or returns existing list state associated with stateName.
+   * The ListState persists values of type T.
+   *
+   * @param stateName  - name of the state variable
+   * @tparam T - type of state variable
+   * @return - instance of ListState of type T that can be used to store state persistently
+   */
+  def getListState[T](stateName: String): ListState[T]
+
   /** Function to return queryInfo for currently running task */
   def getQueryInfo(): QueryInfo
 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
index 25f238b3a25e..9c707c8308ab 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/ValueState.scala
@@ -46,5 +46,5 @@ private[sql] trait ValueState[S] extends Serializable {
   def update(newState: S): Unit
 
   /** Remove this state. */
-  def remove(): Unit
+  def clear(): Unit
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
index 4d6174a81624..d9fbb272ecbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala
@@ -78,7 +78,7 @@ class StatePartitionReader(
 
     StateStoreProvider.createAndInit(
       stateStoreProviderId, keySchema, valueSchema, numColsPrefixKey,
-      useColumnFamilies = false, storeConf, hadoopConf.value)
+      useColumnFamilies = false, storeConf, hadoopConf.value, useMultipleValuesPerKey = false)
   }
 
   private lazy val store: ReadStateStore = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
new file mode 100644
index 000000000000..b6ed48dab579
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
+import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreErrors}
+import org.apache.spark.sql.streaming.ListState
+
+/**
+ * Provides concrete implementation for list of values associated with a state variable
+ * used in the streaming transformWithState operator.
+ *
+ * @param store - reference to the StateStore instance to be used for storing state
+ * @param stateName - name of logical state partition
+ * @tparam S - data type of object that will be stored in the list
+ */
+class ListStateImpl[S](
+     store: StateStore,
+     stateName: String,
+     keyExprEnc: ExpressionEncoder[Any])
+  extends ListState[S] with Logging {
+
+  private val keySerializer = keyExprEnc.createSerializer()
+
+  private val stateTypesEncoder = StateTypesEncoder(keySerializer, stateName)
+
+  store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, numColsPrefixKey = 0,
+    VALUE_ROW_SCHEMA, useMultipleValuesPerKey = true)
+
+  /** Whether state exists or not. */
+   override def exists(): Boolean = {
+     val encodedGroupingKey = stateTypesEncoder.encodeGroupingKey()
+     val stateValue = store.get(encodedGroupingKey, stateName)
+     stateValue != null
+   }
+
+   /**
+    * Get the state value if it exists. If the state does not exist in state store, an
+    * empty iterator is returned.
+    */
+   override def get(): Iterator[S] = {
+     val encodedKey = stateTypesEncoder.encodeGroupingKey()
+     val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
+     new Iterator[S] {
+       override def hasNext: Boolean = {
+         unsafeRowValuesIterator.hasNext
+       }
+
+       override def next(): S = {
+         val valueUnsafeRow = unsafeRowValuesIterator.next()
+         stateTypesEncoder.decodeValue(valueUnsafeRow)
+       }
+     }
+   }
+
+   /** Update the value of the list. */
+   override def put(newState: Array[S]): Unit = {
+     validateNewState(newState)
+
+     val encodedKey = stateTypesEncoder.encodeGroupingKey()
+     var isFirst = true
+
+     newState.foreach { v =>
+       val encodedValue = stateTypesEncoder.encodeValue(v)
+       if (isFirst) {
+         store.put(encodedKey, encodedValue, stateName)
+         isFirst = false
+       } else {
+          store.merge(encodedKey, encodedValue, stateName)
+       }
+     }
+   }
+
+   /** Append an entry to the list. */
+   override def appendValue(newState: S): Unit = {
+     StateStoreErrors.requireNonNullStateValue(newState, stateName)
+     store.merge(stateTypesEncoder.encodeGroupingKey(),
+         stateTypesEncoder.encodeValue(newState), stateName)
+   }
+
+   /** Append an entire list to the existing value. */
+   override def appendList(newState: Array[S]): Unit = {
+     validateNewState(newState)
+
+     val encodedKey = stateTypesEncoder.encodeGroupingKey()
+     newState.foreach { v =>
+       val encodedValue = stateTypesEncoder.encodeValue(v)
+       store.merge(encodedKey, encodedValue, stateName)
+     }
+   }
+
+   /** Remove this state. */
+   override def clear(): Unit = {
+     store.remove(stateTypesEncoder.encodeGroupingKey(), stateName)
+   }
+
+   private def validateNewState(newState: Array[S]): Unit = {
+     StateStoreErrors.requireNonNullStateValue(newState, stateName)
+     StateStoreErrors.requireNonEmptyListStateValue(newState, stateName)
+
+     newState.foreach { v =>
+       StateStoreErrors.requireNonNullStateValue(v, stateName)
+     }
+   }
+ }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
new file mode 100644
index 000000000000..15d77030d57b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.commons.lang3.SerializationUtils
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer
+import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.streaming.state.StateStoreErrors
+import org.apache.spark.sql.types.{BinaryType, StructType}
+
+object StateKeyValueRowSchema {
+  val KEY_ROW_SCHEMA: StructType = new StructType().add("key", BinaryType)
+  val VALUE_ROW_SCHEMA: StructType = new StructType().add("value", BinaryType)
+}
+
+/**
+ * Helper class providing APIs to encode the grouping key, and user provided values
+ * to Spark [[UnsafeRow]].
+ *
+ * CAUTION: StateTypesEncoder class instance is *not* thread-safe.
+ * This class reuses the keyProjection and valueProjection for encoding grouping
+ * key and state value respectively. As UnsafeProjection is not thread safe, this
+ * class is also not thread safe.
+ *
+ * @param keySerializer - serializer to serialize the grouping key of type `GK`
+ *     to an [[InternalRow]]
+ * @param stateName - name of logical state partition
+ * @tparam GK - grouping key type
+ */
+class StateTypesEncoder[GK](
+    keySerializer: Serializer[GK],
+    stateName: String) {
+  import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema._
+
+  private val keyProjection = UnsafeProjection.create(KEY_ROW_SCHEMA)
+  private val valueProjection = UnsafeProjection.create(VALUE_ROW_SCHEMA)
+
+  // TODO: validate places that are trying to encode the key and check if we can eliminate/
+  // add caching for some of these calls.
+  def encodeGroupingKey(): UnsafeRow = {
+    val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
+    if (keyOption.isEmpty) {
+      throw StateStoreErrors.implicitKeyNotFound(stateName)
+    }
+
+    val groupingKey = keyOption.get.asInstanceOf[GK]
+    val keyByteArr = keySerializer.apply(groupingKey).asInstanceOf[UnsafeRow].getBytes()
+    val keyRow = keyProjection(InternalRow(keyByteArr))
+    keyRow
+  }
+
+  def encodeValue[S](value: S): UnsafeRow = {
+    val valueByteArr = SerializationUtils.serialize(value.asInstanceOf[Serializable])
+    val valueRow = valueProjection(InternalRow(valueByteArr))
+    valueRow
+  }
+
+  def decodeValue[S](row: UnsafeRow): S = {
+    SerializationUtils
+      .deserialize(row.getBinary(0))
+      .asInstanceOf[S]
+  }
+}
+
+object StateTypesEncoder {
+  def apply[GK](
+      keySerializer: Serializer[GK],
+      stateName: String): StateTypesEncoder[GK] = {
+    new StateTypesEncoder[GK](keySerializer, stateName)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
index 62c97d11c926..56a325a31e33 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala
@@ -22,7 +22,7 @@ import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.execution.streaming.state.StateStore
-import org.apache.spark.sql.streaming.{QueryInfo, StatefulProcessorHandle, ValueState}
+import org.apache.spark.sql.streaming.{ListState, QueryInfo, StatefulProcessorHandle, ValueState}
 import org.apache.spark.util.Utils
 
 /**
@@ -132,4 +132,10 @@ class StatefulProcessorHandleImpl(
     store.removeColFamilyIfExists(stateName)
   }
 
+  override def getListState[T](stateName: String): ListState[T] = {
+    verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " +
+      "initialization is complete")
+    val resultState = new ListStateImpl[T](store, stateName, keyEncoder)
+    resultState
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 818bef5f34a2..5a80fb1209ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -171,7 +171,8 @@ case class TransformWithStateExec(
         numColsPrefixKey = 0,
         session.sqlContext.sessionState,
         Some(session.sqlContext.streams.stateStoreCoordinator),
-        useColumnFamilies = true
+        useColumnFamilies = true,
+        useMultipleValuesPerKey = true
       ) {
         case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
           processData(store, singleIterator)
@@ -202,7 +203,8 @@ case class TransformWithStateExec(
             numColsPrefixKey = 0,
             useColumnFamilies = true,
             storeConf = storeConf,
-            hadoopConf = broadcastedHadoopConf.value)
+            hadoopConf = broadcastedHadoopConf.value,
+            useMultipleValuesPerKey = true)
 
           val store = stateStoreProvider.getStore(0)
           val outputIterator = processData(store, iter)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
index c1d807144df6..a94a49d88325 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala
@@ -16,15 +16,12 @@
  */
 package org.apache.spark.sql.execution.streaming
 
-import org.apache.commons.lang3.SerializationUtils
-
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreErrors}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.StateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
+import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.ValueState
-import org.apache.spark.sql.types._
 
 /**
  * Class that provides a concrete implementation for a single value state associated with state
@@ -39,34 +36,12 @@ class ValueStateImpl[S](
     stateName: String,
     keyExprEnc: ExpressionEncoder[Any]) extends ValueState[S] with Logging {
 
-  private val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)
-  private val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
   private val keySerializer = keyExprEnc.createSerializer()
 
-  private val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
-  private val valueEncoder = UnsafeProjection.create(schemaForValueRow)
-
-  store.createColFamilyIfAbsent(stateName, schemaForKeyRow, numColsPrefixKey = 0,
-    schemaForValueRow)
-
-  // TODO: validate places that are trying to encode the key and check if we can eliminate/
-  // add caching for some of these calls.
-  private def encodeKey(): UnsafeRow = {
-    val keyOption = ImplicitGroupingKeyTracker.getImplicitKeyOption
-    if (!keyOption.isDefined) {
-      throw StateStoreErrors.implicitKeyNotFound(stateName)
-    }
+  private val stateTypesEncoder = StateTypesEncoder(keySerializer, stateName)
 
-    val keyByteArr = keySerializer.apply(keyOption.get).asInstanceOf[UnsafeRow].getBytes()
-    val keyRow = keyEncoder(InternalRow(keyByteArr))
-    keyRow
-  }
-
-  private def encodeValue(value: S): UnsafeRow = {
-    val valueByteArr = SerializationUtils.serialize(value.asInstanceOf[Serializable])
-    val valueRow = valueEncoder(InternalRow(valueByteArr))
-    valueRow
-  }
+  store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, numColsPrefixKey = 0,
+    VALUE_ROW_SCHEMA)
 
   /** Function to check if state exists. Returns true if present and false otherwise */
   override def exists(): Boolean = {
@@ -75,41 +50,31 @@ class ValueStateImpl[S](
 
   /** Function to return Option of value if exists and None otherwise */
   override def getOption(): Option[S] = {
-    val retRow = getImpl()
-    if (retRow != null) {
-      val resState = SerializationUtils
-        .deserialize(retRow.getBinary(0))
-        .asInstanceOf[S]
-      Some(resState)
-    } else {
-      None
-    }
+    Option(get())
   }
 
   /** Function to return associated value with key if exists and null otherwise */
   override def get(): S = {
     val retRow = getImpl()
     if (retRow != null) {
-      val resState = SerializationUtils
-        .deserialize(retRow.getBinary(0))
-        .asInstanceOf[S]
-      resState
+      stateTypesEncoder.decodeValue[S](retRow)
     } else {
       null.asInstanceOf[S]
     }
   }
 
   private def getImpl(): UnsafeRow = {
-    store.get(encodeKey(), stateName)
+    store.get(stateTypesEncoder.encodeGroupingKey(), stateName)
   }
 
   /** Function to update and overwrite state associated with given key */
   override def update(newState: S): Unit = {
-    store.put(encodeKey(), encodeValue(newState), stateName)
+    store.put(stateTypesEncoder.encodeGroupingKey(),
+      stateTypesEncoder.encodeValue(newState), stateName)
   }
 
   /** Function to remove state for given key */
-  override def remove(): Unit = {
-    store.remove(encodeKey(), stateName)
+  override def clear(): Unit = {
+    store.remove(stateTypesEncoder.encodeGroupingKey(), stateName)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index b23c83f625d6..01e2e7f26083 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -31,7 +31,7 @@ import org.apache.commons.io.IOUtils
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs._
 
-import org.apache.spark.{SparkConf, SparkEnv, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.internal.Logging
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -94,6 +94,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       Iterator[UnsafeRowPair] = {
       map.prefixScan(prefixKey)
     }
+
+    override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
+      throw StateStoreErrors.unsupportedOperationException("multipleValuesPerKey", "HDFSStateStore")
+    }
   }
 
   /** Implementation of [[StateStore]] API which is backed by an HDFS-compatible file system */
@@ -118,8 +122,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
         colFamilyName: String,
         keySchema: StructType,
         numColsPrefixKey: Int,
-        valueSchema: StructType): Unit = {
-      throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3193")
+        valueSchema: StructType,
+        useMultipleValuesPerKey: Boolean = false): Unit = {
+      throw StateStoreErrors.multipleColumnFamiliesNotSupported("HDFSStateStoreProvider")
     }
 
     override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = {
@@ -208,7 +213,16 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
     override def removeColFamilyIfExists(colFamilyName: String): Unit = {
       throw StateStoreErrors.removingColumnFamiliesNotSupported(
         "HDFSBackedStateStoreProvider")
+    }
 
+    override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
+      throw StateStoreErrors.unsupportedOperationException("multipleValuesPerKey", "HDFSStateStore")
+    }
+
+    override def merge(key: UnsafeRow,
+        value: UnsafeRow,
+        colFamilyName: String): Unit = {
+      throw StateStoreErrors.unsupportedOperationException("merge", "HDFSStateStore")
     }
   }
 
@@ -255,7 +269,8 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       numColsPrefixKey: Int,
       useColumnFamilies: Boolean,
       storeConf: StateStoreConf,
-      hadoopConf: Configuration): Unit = {
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean = false): Unit = {
     this.stateStoreId_ = stateStoreId
     this.keySchema = keySchema
     this.valueSchema = valueSchema
@@ -268,6 +283,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
       throw StateStoreErrors.multipleColumnFamiliesNotSupported("HDFSStateStoreProvider")
     }
 
+    if (useMultipleValuesPerKey) {
+      throw StateStoreErrors.unsupportedOperationException("multipleValuesPerKey", "HDFSStateStore")
+    }
+
     require((keySchema.length == 0 && numColsPrefixKey == 0) ||
       (keySchema.length > numColsPrefixKey), "The number of columns in the key must be " +
       "greater than the number of columns for prefix key!")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index b3d981e4b25d..e819bf870015 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -109,6 +109,7 @@ class RocksDB(
   }
 
   columnFamilyOptions.setCompressionType(getCompressionType(conf.compression))
+  columnFamilyOptions.setMergeOperator(new StringAppendOperator())
 
   private val dbOptions =
     new Options(new DBOptions(), columnFamilyOptions) // options to open the RocksDB
@@ -117,6 +118,7 @@ class RocksDB(
   dbOptions.setTableFormatConfig(tableFormatConfig)
   dbOptions.setMaxOpenFiles(conf.maxOpenFiles)
   dbOptions.setAllowFAllocate(conf.allowFAllocate)
+  dbOptions.setMergeOperator(new StringAppendOperator())
 
   if (conf.boundedMemoryUsage) {
     dbOptions.setWriteBufferManager(writeBufferManager)
@@ -228,6 +230,9 @@ class RocksDB(
 
             case RecordType.DELETE_RECORD =>
               remove(key, colFamilyName)
+
+            case RecordType.MERGE_RECORD =>
+              merge(key, value, colFamilyName)
           }
         }
       } finally {
@@ -316,6 +321,38 @@ class RocksDB(
     }
   }
 
+  /**
+   * Merge the given value for the given key. This is equivalent to the Atomic
+   * Read-Modify-Write operation in RocksDB, known as the "Merge" operation. The
+   * modification is appending the provided value to current list of values for
+   * the given key.
+   *
+   * @note This operation requires that the encoder used can decode multiple values for
+   * a key from the values byte array.
+   *
+   * @note This update is not committed to disk until commit() is called.
+   */
+  def merge(
+      key: Array[Byte],
+      value: Array[Byte],
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+    if (!useColumnFamilies) {
+      throw new RuntimeException("Merge operation uses changelog checkpointing v2 which" +
+        " requires column families to be enabled.")
+    }
+    verifyColFamilyExists(colFamilyName)
+
+    if (conf.trackTotalNumberOfRows) {
+      val oldValue = db.get(colFamilyNameToHandleMap(colFamilyName), readOptions, key)
+      if (oldValue == null) {
+        numKeysOnWritingVersion += 1
+      }
+    }
+    db.merge(colFamilyNameToHandleMap(colFamilyName), writeOptions, key, value)
+
+    changelogWriter.foreach(_.merge(key, value, colFamilyName))
+  }
+
   /**
    * Remove the key if present.
    * @note This update is not committed to disk until commit() is called.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index be1bb4689507..8f58bccd948b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider.{STATE_ENCODING_NUM_VERSION_BYTES, STATE_ENCODING_VERSION}
 import org.apache.spark.sql.types.{StructField, StructType}
@@ -26,14 +27,15 @@ sealed trait RocksDBKeyStateEncoder {
   def supportPrefixKeyScan: Boolean
   def encodePrefixKey(prefixKey: UnsafeRow): Array[Byte]
   def extractPrefixKey(key: UnsafeRow): UnsafeRow
-
   def encodeKey(row: UnsafeRow): Array[Byte]
   def decodeKey(keyBytes: Array[Byte]): UnsafeRow
 }
 
 sealed trait RocksDBValueStateEncoder {
+  def supportsMultipleValuesPerKey: Boolean
   def encodeValue(row: UnsafeRow): Array[Byte]
   def decodeValue(valueBytes: Array[Byte]): UnsafeRow
+  def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow]
 }
 
 object RocksDBStateEncoder {
@@ -47,8 +49,14 @@ object RocksDBStateEncoder {
     }
   }
 
-  def getValueEncoder(valueSchema: StructType): RocksDBValueStateEncoder = {
-    new SingleValueStateEncoder(valueSchema)
+  def getValueEncoder(
+      valueSchema: StructType,
+      useMultipleValuesPerKey: Boolean): RocksDBValueStateEncoder = {
+    if (useMultipleValuesPerKey) {
+      new MultiValuedStateEncoder(valueSchema)
+    } else {
+      new SingleValueStateEncoder(valueSchema)
+    }
   }
 
   /**
@@ -226,6 +234,82 @@ class NoPrefixKeyStateEncoder(keySchema: StructType)
   }
 }
 
+/**
+ * Supports encoding multiple values per key in RocksDB.
+ * A single value is encoded in the format below, where first value is number of bytes
+ * in actual encodedUnsafeRow followed by the encoded value itself.
+ *
+ * |---size(bytes)--|--unsafeRowEncodedBytes--|
+ *
+ * Multiple values are separated by a delimiter character.
+ *
+ * This encoder supports RocksDB StringAppendOperator merge operator. Values encoded can be
+ * merged in RocksDB using merge operation, and all merged values can be read using decodeValues
+ * operation.
+ */
+class MultiValuedStateEncoder(valueSchema: StructType)
+  extends RocksDBValueStateEncoder with Logging {
+
+  import RocksDBStateEncoder._
+
+  // Reusable objects
+  private val valueRow = new UnsafeRow(valueSchema.size)
+
+  override def encodeValue(row: UnsafeRow): Array[Byte] = {
+    val bytes = encodeUnsafeRow(row)
+    val numBytes = bytes.length
+
+    val encodedBytes = new Array[Byte](java.lang.Integer.BYTES + bytes.length)
+    Platform.putInt(encodedBytes, Platform.BYTE_ARRAY_OFFSET, numBytes)
+    Platform.copyMemory(bytes, Platform.BYTE_ARRAY_OFFSET,
+      encodedBytes, java.lang.Integer.BYTES + Platform.BYTE_ARRAY_OFFSET, bytes.length)
+
+    encodedBytes
+  }
+
+  override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
+    if (valueBytes == null) {
+      null
+    } else {
+      val numBytes = Platform.getInt(valueBytes, Platform.BYTE_ARRAY_OFFSET)
+      val encodedValue = new Array[Byte](numBytes)
+      Platform.copyMemory(valueBytes, java.lang.Integer.BYTES + Platform.BYTE_ARRAY_OFFSET,
+        encodedValue, Platform.BYTE_ARRAY_OFFSET, numBytes)
+      decodeToUnsafeRow(encodedValue, valueRow)
+    }
+  }
+
+  override def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow] = {
+    if (valueBytes == null) {
+      Seq().iterator
+    } else {
+      new Iterator[UnsafeRow] {
+        private var pos: Int = Platform.BYTE_ARRAY_OFFSET
+        private val maxPos = Platform.BYTE_ARRAY_OFFSET + valueBytes.length
+
+        override def hasNext: Boolean = {
+          pos < maxPos
+        }
+
+        override def next(): UnsafeRow = {
+          val numBytes = Platform.getInt(valueBytes, pos)
+
+          pos += java.lang.Integer.BYTES
+          val encodedValue = new Array[Byte](numBytes)
+          Platform.copyMemory(valueBytes, pos,
+            encodedValue, Platform.BYTE_ARRAY_OFFSET, numBytes)
+
+          pos += numBytes
+          pos += 1 // eat the delimiter character
+          decodeToUnsafeRow(encodedValue, valueRow)
+        }
+      }
+    }
+  }
+
+  override def supportsMultipleValuesPerKey: Boolean = true
+}
+
 /**
  * RocksDB Value Encoder for UnsafeRow that only supports single value.
  *
@@ -257,4 +341,10 @@ class SingleValueStateEncoder(valueSchema: StructType)
   override def decodeValue(valueBytes: Array[Byte]): UnsafeRow = {
     decodeToUnsafeRow(valueBytes, valueRow)
   }
+
+  override def supportsMultipleValuesPerKey: Boolean = false
+
+  override def decodeValues(valueBytes: Array[Byte]): Iterator[UnsafeRow] = {
+    throw new IllegalStateException("This encoder doesn't support multiple values!")
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 0c3487ba4dd7..7374abdbde98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -52,14 +52,15 @@ private[sql] class RocksDBStateStoreProvider
         colFamilyName: String,
         keySchema: StructType,
         numColsPrefixKey: Int,
-        valueSchema: StructType): Unit = {
+        valueSchema: StructType,
+        useMultipleValuesPerKey: Boolean = false): Unit = {
       verify(colFamilyName != StateStore.DEFAULT_COL_FAMILY_NAME,
         s"Failed to create column family with reserved_name=$colFamilyName")
       verify(useColumnFamilies, "Column families are not supported in this store")
       rocksDB.createColFamilyIfAbsent(colFamilyName)
       keyValueEncoderMap.putIfAbsent(colFamilyName,
         (RocksDBStateEncoder.getKeyEncoder(keySchema, numColsPrefixKey),
-         RocksDBStateEncoder.getValueEncoder(valueSchema)))
+         RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey)))
     }
 
     override def get(key: UnsafeRow, colFamilyName: String): UnsafeRow = {
@@ -75,6 +76,42 @@ private[sql] class RocksDBStateStoreProvider
       value
     }
 
+    /**
+     * Provides an iterator containing all values of a non-null key.
+     *
+     * Inside RocksDB, the values are merged together and stored as a byte Array.
+     * This operation relies on state store value encoder to be able to split the
+     * single array into multiple values.
+     *
+     * Also see [[MultiValuedStateEncoder]] which supports encoding/decoding multiple
+     * values per key.
+     */
+    override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
+      verify(key != null, "Key cannot be null")
+
+      val kvEncoder = keyValueEncoderMap.get(colFamilyName)
+      val valueEncoder = kvEncoder._2
+      val keyEncoder = kvEncoder._1
+
+      verify(valueEncoder.supportsMultipleValuesPerKey, "valuesIterator requires a encoder " +
+      "that supports multiple values for a single key.")
+      val encodedKey = rocksDB.get(keyEncoder.encodeKey(key), colFamilyName)
+      valueEncoder.decodeValues(encodedKey)
+    }
+
+    override def merge(key: UnsafeRow, value: UnsafeRow,
+        colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit = {
+      verify(state == UPDATING, "Cannot put after already committed or aborted")
+      val kvEncoder = keyValueEncoderMap.get(colFamilyName)
+      val keyEncoder = kvEncoder._1
+      val valueEncoder = kvEncoder._2
+      verify(valueEncoder.supportsMultipleValuesPerKey, "Merge operation requires an encoder" +
+        " which supports multiple values for a single key")
+      verify(key != null, "Key cannot be null")
+      require(value != null, "Cannot put a null value")
+      rocksDB.merge(keyEncoder.encodeKey(key), valueEncoder.encodeValue(value), colFamilyName)
+    }
+
     override def put(key: UnsafeRow, value: UnsafeRow, colFamilyName: String): Unit = {
       verify(state == UPDATING, "Cannot put after already committed or aborted")
       verify(key != null, "Key cannot be null")
@@ -228,7 +265,8 @@ private[sql] class RocksDBStateStoreProvider
       numColsPrefixKey: Int,
       useColumnFamilies: Boolean,
       storeConf: StateStoreConf,
-      hadoopConf: Configuration): Unit = {
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean = false): Unit = {
     this.stateStoreId_ = stateStoreId
     this.keySchema = keySchema
     this.valueSchema = valueSchema
@@ -240,9 +278,16 @@ private[sql] class RocksDBStateStoreProvider
       (keySchema.length > numColsPrefixKey), "The number of columns in the key must be " +
       "greater than the number of columns for prefix key!")
 
+    if (useMultipleValuesPerKey) {
+      require(numColsPrefixKey == 0, "Both multiple values per key, and prefix key are not " +
+        "supported simultaneously.")
+      require(useColumnFamilies, "Multiple values per key support requires column families to be" +
+        " enabled in RocksDBStateStore.")
+    }
+
     keyValueEncoderMap.putIfAbsent(StateStore.DEFAULT_COL_FAMILY_NAME,
       (RocksDBStateEncoder.getKeyEncoder(keySchema, numColsPrefixKey),
-       RocksDBStateEncoder.getValueEncoder(valueSchema)))
+       RocksDBStateEncoder.getValueEncoder(valueSchema, useMultipleValuesPerKey)))
 
     rocksDB // lazy initialization
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 7207a4746196..e2eb0c0728d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -67,6 +67,17 @@ trait ReadStateStore {
   def get(key: UnsafeRow,
     colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): UnsafeRow
 
+  /**
+   * Provides an iterator containing all values of a non-null key. If key does not exist,
+   * an empty iterator is returned. Implementations should make sure to return an empty
+   * iterator if the key does not exist.
+   *
+   * It is expected to throw exception if Spark calls this method without setting
+   * multipleValuesPerKey as true for the column family.
+   */
+  def valuesIterator(key: UnsafeRow,
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRow]
+
   /**
    * Return an iterator containing all the key-value pairs which are matched with
    * the given prefix key.
@@ -116,7 +127,8 @@ trait StateStore extends ReadStateStore {
       colFamilyName: String,
       keySchema: StructType,
       numColsPrefixKey: Int,
-      valueSchema: StructType): Unit
+      valueSchema: StructType,
+      useMultipleValuesPerKey: Boolean = false): Unit
 
   /**
    * Put a new non-null value for a non-null key. Implementations must be aware that the UnsafeRows
@@ -131,6 +143,16 @@ trait StateStore extends ReadStateStore {
   def remove(key: UnsafeRow,
     colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit
 
+  /**
+   * Merges the provided value with existing values of a non-null key. If a existing
+   * value does not exist, this operation behaves as [[StateStore.put()]].
+   *
+   * It is expected to throw exception if Spark calls this method without setting
+   * multipleValuesPerKey as true for the column family.
+   */
+  def merge(key: UnsafeRow, value: UnsafeRow,
+      colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Unit
+
   /**
    * Commit all the updates that have been made to the store, and return the new version.
    * Implementations should ensure that no more updates (puts, removes) can be after a commit in
@@ -182,6 +204,10 @@ class WrappedReadStateStore(store: StateStore) extends ReadStateStore {
   override def prefixScan(prefixKey: UnsafeRow,
     colFamilyName: String = StateStore.DEFAULT_COL_FAMILY_NAME): Iterator[UnsafeRowPair] =
     store.prefixScan(prefixKey, colFamilyName)
+
+  override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
+    store.valuesIterator(key, colFamilyName)
+  }
 }
 
 /**
@@ -291,6 +317,8 @@ trait StateStoreProvider {
    *                          families
    * @param storeConfs Configurations used by the StateStores
    * @param hadoopConf Hadoop configuration that could be used by StateStore to save state data
+   * @param useMultipleValuesPerKey Whether the underlying state store needs to support multiple
+   *                                values for a single key.
    */
   def init(
       stateStoreId: StateStoreId,
@@ -299,7 +327,8 @@ trait StateStoreProvider {
       numColsPrefixKey: Int,
       useColumnFamilies: Boolean,
       storeConfs: StateStoreConf,
-      hadoopConf: Configuration): Unit
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean = false): Unit
 
   /**
    * Return the id of the StateStores this provider will generate.
@@ -353,10 +382,11 @@ object StateStoreProvider {
       numColsPrefixKey: Int,
       useColumnFamilies: Boolean,
       storeConf: StateStoreConf,
-      hadoopConf: Configuration): StateStoreProvider = {
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean): StateStoreProvider = {
     val provider = create(storeConf.providerClass)
     provider.init(providerId.storeId, keySchema, valueSchema, numColsPrefixKey,
-      useColumnFamilies, storeConf, hadoopConf)
+      useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey)
     provider
   }
 
@@ -549,12 +579,13 @@ object StateStore extends Logging {
       version: Long,
       useColumnFamilies: Boolean,
       storeConf: StateStoreConf,
-      hadoopConf: Configuration): ReadStateStore = {
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean = false): ReadStateStore = {
     if (version < 0) {
       throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
     }
     val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
-      numColsPrefixKey, useColumnFamilies, storeConf, hadoopConf)
+      numColsPrefixKey, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey)
     storeProvider.getReadStore(version)
   }
 
@@ -567,12 +598,13 @@ object StateStore extends Logging {
       version: Long,
       useColumnFamilies: Boolean,
       storeConf: StateStoreConf,
-      hadoopConf: Configuration): StateStore = {
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean = false): StateStore = {
     if (version < 0) {
       throw QueryExecutionErrors.unexpectedStateStoreVersion(version)
     }
     val storeProvider = getStateStoreProvider(storeProviderId, keySchema, valueSchema,
-      numColsPrefixKey, useColumnFamilies, storeConf, hadoopConf)
+      numColsPrefixKey, useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey)
     storeProvider.getStore(version)
   }
 
@@ -583,7 +615,8 @@ object StateStore extends Logging {
       numColsPrefixKey: Int,
       useColumnFamilies: Boolean,
       storeConf: StateStoreConf,
-      hadoopConf: Configuration): StateStoreProvider = {
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean): StateStoreProvider = {
     loadedProviders.synchronized {
       startMaintenanceIfNeeded(storeConf)
 
@@ -617,7 +650,7 @@ object StateStore extends Logging {
           storeProviderId,
           StateStoreProvider.createAndInit(
             storeProviderId, keySchema, valueSchema, numColsPrefixKey,
-            useColumnFamilies, storeConf, hadoopConf)
+            useColumnFamilies, storeConf, hadoopConf, useMultipleValuesPerKey)
         )
       }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
index d4a1c3fc63c4..30cf49d8e56d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala
@@ -30,6 +30,7 @@ import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.streaming.CheckpointFileManager
 import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream
+import org.apache.spark.sql.execution.streaming.state.RecordType.RecordType
 import org.apache.spark.util.NextIterator
 
 /**
@@ -41,6 +42,7 @@ object RecordType extends Enumeration {
   val EOF_RECORD = Value("eof_record")
   val PUT_RECORD = Value("put_record")
   val DELETE_RECORD = Value("delete_record")
+  val MERGE_RECORD = Value("merge_record")
 
   // Generate byte representation of each record type
   def getRecordTypeAsByte(recordType: RecordType): Byte = {
@@ -48,6 +50,7 @@ object RecordType extends Enumeration {
       case EOF_RECORD => 0x00.toByte
       case PUT_RECORD => 0x01.toByte
       case DELETE_RECORD => 0x10.toByte
+      case MERGE_RECORD => 0x11.toByte
     }
   }
 
@@ -57,6 +60,7 @@ object RecordType extends Enumeration {
       case 0x00 => EOF_RECORD
       case 0x01 => PUT_RECORD
       case 0x10 => DELETE_RECORD
+      case 0x11 => MERGE_RECORD
       case _ => throw new RuntimeException(s"Found invalid record type for value=$byte")
     }
   }
@@ -91,6 +95,8 @@ abstract class StateStoreChangelogWriter(
 
   def delete(key: Array[Byte], colFamilyName: String): Unit
 
+  def merge(key: Array[Byte], value: Array[Byte], colFamilyName: String): Unit
+
   def abort(): Unit = {
     try {
       if (backingFileStream != null) backingFileStream.cancel()
@@ -155,6 +161,11 @@ class StateStoreChangelogWriterV1(
       operationName = "Delete", entity = "changelog writer v1")
   }
 
+  override def merge(key: Array[Byte], value: Array[Byte], colFamilyName: String): Unit = {
+    throw new UnsupportedOperationException("Operation not supported with state " +
+      "changelog writer v1")
+  }
+
   override def commit(): Unit = {
     try {
       // -1 in the key length field mean EOF.
@@ -194,15 +205,7 @@ class StateStoreChangelogWriterV2(
   }
 
   override def put(key: Array[Byte], value: Array[Byte], colFamilyName: String): Unit = {
-    assert(compressedStream != null)
-    compressedStream.write(RecordType.getRecordTypeAsByte(RecordType.PUT_RECORD))
-    compressedStream.writeInt(key.size)
-    compressedStream.write(key)
-    compressedStream.writeInt(value.size)
-    compressedStream.write(value)
-    compressedStream.writeInt(colFamilyName.getBytes.size)
-    compressedStream.write(colFamilyName.getBytes)
-    size += 1
+    writePutOrMergeRecord(key, value, colFamilyName, RecordType.PUT_RECORD)
   }
 
   override def delete(key: Array[Byte]): Unit = {
@@ -222,6 +225,26 @@ class StateStoreChangelogWriterV2(
     size += 1
   }
 
+  override def merge(key: Array[Byte], value: Array[Byte], colFamilyName: String): Unit = {
+    writePutOrMergeRecord(key, value, colFamilyName, RecordType.MERGE_RECORD)
+  }
+
+  private def writePutOrMergeRecord(key: Array[Byte],
+      value: Array[Byte],
+      colFamilyName: String,
+      recordType: RecordType): Unit = {
+    assert(recordType == RecordType.PUT_RECORD || recordType == RecordType.MERGE_RECORD)
+    assert(compressedStream != null)
+    compressedStream.write(RecordType.getRecordTypeAsByte(recordType))
+    compressedStream.writeInt(key.size)
+    compressedStream.write(key)
+    compressedStream.writeInt(value.size)
+    compressedStream.write(value)
+    compressedStream.writeInt(colFamilyName.getBytes.size)
+    compressedStream.write(colFamilyName.getBytes)
+    size += 1
+  }
+
   def commit(): Unit = {
     try {
       // write EOF_RECORD to signal end of file
@@ -352,6 +375,13 @@ class StateStoreChangelogReaderV2(
           (RecordType.DELETE_RECORD, keyBuffer, null,
             colFamilyNameBuffer.map(_.toChar).mkString)
 
+        case RecordType.MERGE_RECORD =>
+          val keyBuffer = parseBuffer(input)
+          val valueBuffer = parseBuffer(input)
+          val colFamilyNameBuffer = parseBuffer(input)
+          (RecordType.MERGE_RECORD, keyBuffer, valueBuffer,
+            colFamilyNameBuffer.map(_.toChar).mkString)
+
         case _ =>
           throw new IOException("Failed to process unknown record type")
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index bbc6d4c78f90..6f4c3d4c9675 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -51,6 +51,18 @@ object StateStoreErrors {
     StateStoreUnsupportedOperationException = {
       new StateStoreUnsupportedOperationException(operationName, entity)
     }
+
+  def requireNonNullStateValue(value: Any, stateName: String): Unit = {
+    SparkException.require(value != null,
+      errorClass = "ILLEGAL_STATE_STORE_VALUE.NULL_VALUE",
+      messageParameters = Map("stateName" -> stateName))
+  }
+
+  def requireNonEmptyListStateValue[S](value: Array[S], stateName: String): Unit = {
+    SparkException.require(value.nonEmpty,
+      errorClass = "ILLEGAL_STATE_STORE_VALUE.EMPTY_LIST_VALUE",
+      messageParameters = Map("stateName" -> stateName))
+  }
 }
 
 class StateStoreMultipleColumnFamiliesNotSupportedException(stateStoreProvider: String)
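
These helpers give state implementations a single place to validate user-supplied values before anything is written to the store. A hypothetical call site is sketched below (the method name and body are illustrative; ListStateImpl in this patch performs the equivalent checks):

    import org.apache.spark.sql.execution.streaming.state.StateStoreErrors

    // Hypothetical call site: reject null values and empty lists up front so the
    // failure surfaces as ILLEGAL_STATE_STORE_VALUE rather than a low-level error.
    def validateListInput(values: Array[String], stateName: String): Unit = {
      StateStoreErrors.requireNonNullStateValue(values, stateName)
      StateStoreErrors.requireNonEmptyListStateValue(values, stateName)
      values.foreach { v =>
        StateStoreErrors.requireNonNullStateValue(v, stateName)
      }
    }
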
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
index cb02c5d3e775..1af51c49eaa5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala
@@ -113,7 +113,8 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     sessionState: SessionState,
     @transient private val storeCoordinator: Option[StateStoreCoordinatorRef],
     useColumnFamilies: Boolean = false,
-    extraOptions: Map[String, String] = Map.empty)
+    extraOptions: Map[String, String] = Map.empty,
+    useMultipleValuesPerKey: Boolean = false)
   extends BaseStateStoreRDD[T, U](dataRDD, checkpointLocation, queryRunId, operatorId,
     sessionState, storeCoordinator, extraOptions) {
 
@@ -125,7 +126,7 @@ class StateStoreRDD[T: ClassTag, U: ClassTag](
     val inputIter = dataRDD.iterator(partition, ctxt)
     val store = StateStore.get(
       storeProviderId, keySchema, valueSchema, numColsPrefixKey, storeVersion,
-      useColumnFamilies, storeConf, hadoopConfBroadcast.value.value)
+      useColumnFamilies, storeConf, hadoopConfBroadcast.value.value, useMultipleValuesPerKey)
     storeUpdateFunction(store, inputIter)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
index b2a3ebd89157..b35cf2492666 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -486,7 +486,8 @@ class SymmetricHashJoinStateManager(
         // This class will manage the state store provider by itself.
         stateStoreProvider = StateStoreProvider.createAndInit(
           storeProviderId, keySchema, valueSchema, numColsPrefixKey = 0,
-          useColumnFamilies = false, storeConf, hadoopConf)
+          useColumnFamilies = false, storeConf, hadoopConf,
+          useMultipleValuesPerKey = false)
         stateStoreProvider.getStore(stateInfo.get.storeVersion)
       }
       logInfo(s"Loaded store ${store.id}")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
index 2e50c59afc2b..12cd7b8a127e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala
@@ -57,7 +57,8 @@ package object state {
         sessionState: SessionState,
         storeCoordinator: Option[StateStoreCoordinatorRef],
         useColumnFamilies: Boolean = false,
-        extraOptions: Map[String, String] = Map.empty)(
+        extraOptions: Map[String, String] = Map.empty,
+        useMultipleValuesPerKey: Boolean = false)(
         storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
 
       val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
@@ -82,7 +83,8 @@ package object state {
         sessionState,
         storeCoordinator,
         useColumnFamilies,
-        extraOptions)
+        extraOptions,
+        useMultipleValuesPerKey)
     }
 
     /** Map each partition of an RDD along with data in a [[ReadStateStore]]. */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
index b7a738786e3f..fa5889891b93 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala
@@ -34,7 +34,8 @@ class MemoryStateStore extends StateStore() {
       colFamilyName: String,
       keySchema: StructType,
       numColsPrefixKey: Int,
-      valueSchema: StructType): Unit = {
+      valueSchema: StructType,
+      useMultipleValuesPerKey: Boolean = false): Unit = {
     throw StateStoreErrors.multipleColumnFamiliesNotSupported("MemoryStateStoreProvider")
   }
 
@@ -64,4 +65,12 @@ class MemoryStateStore extends StateStore() {
   override def prefixScan(prefixKey: UnsafeRow, colFamilyName: String): Iterator[UnsafeRowPair] = {
     throw new UnsupportedOperationException("Doesn't support prefix scan!")
   }
+
+  override def merge(key: UnsafeRow, value: UnsafeRow, colFamilyName: String): Unit = {
+    throw new UnsupportedOperationException("Doesn't support multiple values per key")
+  }
+
+  override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = {
+    throw new UnsupportedOperationException("Doesn't support multiple values per key")
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index f2811a23fd8a..1e838ccdb023 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -182,6 +182,41 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
     }
   }
 
+  test("validate rocksdb values iterator correctness") {
+    withSQLConf(SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1") {
+      tryWithProviderResource(newStoreProvider(useColumnFamilies = true,
+        useMultipleValuesPerKey = true)) { provider =>
+        val store = provider.getStore(0)
+        // Verify state after updating
+        put(store, "a", 0, 1)
+
+        val iterator0 = store.valuesIterator(dataToKeyRow("a", 0))
+
+        assert(iterator0.hasNext)
+        assert(valueRowToData(iterator0.next()) === 1)
+        assert(!iterator0.hasNext)
+
+        merge(store, "a", 0, 2)
+        merge(store, "a", 0, 3)
+
+        val iterator1 = store.valuesIterator(dataToKeyRow("a", 0))
+
+        (1 to 3).foreach { i =>
+          assert(iterator1.hasNext)
+          assert(valueRowToData(iterator1.next()) === i)
+        }
+
+        assert(!iterator1.hasNext)
+
+        remove(store, _._1 == "a")
+        val iterator2 = store.valuesIterator(dataToKeyRow("a", 0))
+        assert(!iterator2.hasNext)
+
+        assert(get(store, "a", 0).isEmpty)
+      }
+    }
+  }
+
   override def newStoreProvider(): RocksDBStateStoreProvider = {
     newStoreProvider(StateStoreId(newDir(), Random.nextInt(), 0))
   }
@@ -200,6 +235,14 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       useColumnFamilies = useColumnFamilies)
   }
 
+  def newStoreProvider(useColumnFamilies: Boolean,
+      useMultipleValuesPerKey: Boolean): RocksDBStateStoreProvider = {
+    newStoreProvider(StateStoreId(newDir(), Random.nextInt(), 0), numColsPrefixKey = 0,
+      useColumnFamilies = useColumnFamilies,
+      useMultipleValuesPerKey = useMultipleValuesPerKey
+    )
+  }
+
   def newStoreProvider(storeId: StateStoreId, conf: Configuration): RocksDBStateStoreProvider = {
     newStoreProvider(storeId, numColsPrefixKey = -1, conf = conf)
   }
@@ -213,12 +256,19 @@ class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvid
       numColsPrefixKey: Int,
       sqlConf: Option[SQLConf] = None,
       conf: Configuration = new Configuration,
-      useColumnFamilies: Boolean = false): RocksDBStateStoreProvider = {
+      useColumnFamilies: Boolean = false,
+      useMultipleValuesPerKey: Boolean = false): RocksDBStateStoreProvider = {
     val provider = new RocksDBStateStoreProvider()
     provider.init(
-      storeId, keySchema, valueSchema, numColsPrefixKey = numColsPrefixKey,
+      storeId,
+      keySchema,
+      valueSchema,
+      numColsPrefixKey = numColsPrefixKey,
       useColumnFamilies,
-      new StateStoreConf(sqlConf.getOrElse(SQLConf.get)), conf)
+      new StateStoreConf(sqlConf.getOrElse(SQLConf.get)),
+      conf,
+      useMultipleValuesPerKey
+    )
     provider
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
index 6a4ad10d9a7f..c8459cba9676 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala
@@ -760,6 +760,9 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared
     (1 to 5).foreach { i =>
       changelogWriter.put(i.toString, i.toString, StateStore.DEFAULT_COL_FAMILY_NAME)
     }
+    (1 to 5).foreach { i =>
+      changelogWriter.merge(i.toString, i.toString, StateStore.DEFAULT_COL_FAMILY_NAME)
+    }
 
     (2 to 4).foreach { j =>
       changelogWriter.delete(j.toString, StateStore.DEFAULT_COL_FAMILY_NAME)
@@ -771,6 +774,9 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared
     val expectedEntries = (1 to 5).map { i =>
       (RecordType.PUT_RECORD, i.toString.getBytes,
         i.toString.getBytes, StateStore.DEFAULT_COL_FAMILY_NAME)
+    } ++ (1 to 5).map { i =>
+      (RecordType.MERGE_RECORD, i.toString.getBytes,
+        i.toString.getBytes, StateStore.DEFAULT_COL_FAMILY_NAME)
     } ++ (2 to 4).map { j =>
       (RecordType.DELETE_RECORD, j.toString.getBytes,
         null, StateStore.DEFAULT_COL_FAMILY_NAME)
@@ -792,9 +798,11 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared
       dfsRootDir.getAbsolutePath, Utils.createTempDir(), new Configuration)
     val changelogWriter = fileManager.getChangeLogWriter(1, true)
     (1 to 5).foreach(i => changelogWriter.put(i.toString, i.toString, testColFamily1))
+    (1 to 5).foreach(i => changelogWriter.merge(i.toString, i.toString, testColFamily1))
     (2 to 4).foreach(j => changelogWriter.delete(j.toString, testColFamily1))
 
     (1 to 5).foreach(i => changelogWriter.put(i.toString, i.toString, testColFamily2))
+    (1 to 5).foreach(i => changelogWriter.merge(i.toString, i.toString, testColFamily2))
     (2 to 4).foreach(j => changelogWriter.delete(j.toString, testColFamily2))
 
     changelogWriter.commit()
@@ -803,6 +811,9 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared
     val expectedEntriesForColFamily1 = (1 to 5).map { i =>
       (RecordType.PUT_RECORD, i.toString.getBytes,
         i.toString.getBytes, testColFamily1)
+    } ++ (1 to 5).map { i =>
+      (RecordType.MERGE_RECORD, i.toString.getBytes,
+        i.toString.getBytes, testColFamily1)
     } ++ (2 to 4).map { j =>
       (RecordType.DELETE_RECORD, j.toString.getBytes,
         null, testColFamily1)
@@ -811,6 +822,9 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared
     val expectedEntriesForColFamily2 = (1 to 5).map { i =>
       (RecordType.PUT_RECORD, i.toString.getBytes,
         i.toString.getBytes, testColFamily2)
+    } ++ (1 to 5).map { i =>
+      (RecordType.MERGE_RECORD, i.toString.getBytes,
+        i.toString.getBytes, testColFamily2)
     } ++ (2 to 4).map { j =>
       (RecordType.DELETE_RECORD, j.toString.getBytes,
         null, testColFamily2)
@@ -845,6 +859,76 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared
     }
   }
 
+  test("ensure merge operation is not supported if column families is not enabled") {
+    withTempDir { dir =>
+      val remoteDir = Utils.createTempDir().toString
+      val conf = dbConf.copy(minDeltasForSnapshot = 5, compactOnCommit = false)
+      new File(remoteDir).delete() // to make sure that the directory gets created
+      withDB(remoteDir, conf = conf, useColumnFamilies = false) { db =>
+        db.load(0)
+        db.put("a", "1")
+        intercept[RuntimeException](
+          db.merge("a", "2")
+        )
+      }
+    }
+  }
+
+  test("RocksDB: ensure merge operation correctness") {
+    withTempDir { dir =>
+      val remoteDir = Utils.createTempDir().toString
+      // minDeltasForSnapshot being 5 ensures that only changelog files are created
+      // for the 3 commits below
+      val conf = dbConf.copy(minDeltasForSnapshot = 5, compactOnCommit = false)
+      new File(remoteDir).delete() // to make sure that the directory gets created
+      withDB(remoteDir, conf = conf, useColumnFamilies = true) { db =>
+        db.load(0)
+        db.createColFamilyIfAbsent("cf1")
+        db.createColFamilyIfAbsent("cf2")
+        db.put("a", "1", "cf1")
+        db.merge("a", "2", "cf1")
+        db.put("a", "3", "cf2")
+        db.commit()
+
+        db.load(1)
+        db.put("a", "2")
+        db.merge("a", "3", "cf1")
+        db.merge("a", "4", "cf2")
+        db.commit()
+
+        db.load(2)
+        db.remove("a", "cf1")
+        db.merge("a", "5")
+        db.merge("a", "6", "cf2")
+        db.commit()
+
+        db.load(1)
+        assert(new String(db.get("a", "cf1")) === "1,2")
+        assert(new String(db.get("a", "cf2")) === "3")
+        assert(db.get("a") === null)
+        assert(db.iterator("cf1").map(toStr).toSet === Set(("a", "1,2")))
+        assert(db.iterator("cf2").map(toStr).toSet === Set(("a", "3")))
+        assert(db.iterator().isEmpty)
+
+        db.load(2)
+        assert(new String(db.get("a", "cf1")) === "1,2,3")
+        assert(new String(db.get("a", "cf2")) === "3,4")
+        assert(new String(db.get("a")) === "2")
+        assert(db.iterator("cf1").map(toStr).toSet === Set(("a", "1,2,3")))
+        assert(db.iterator("cf2").map(toStr).toSet === Set(("a", "3,4")))
+        assert(db.iterator().map(toStr).toSet === Set(("a", "2")))
+
+        db.load(3)
+        assert(db.get("a", "cf1") === null)
+        assert(new String(db.get("a", "cf2")) === "3,4,6")
+        assert(new String(db.get("a")) === "2,5")
+        assert(db.iterator("cf1").isEmpty)
+        assert(db.iterator("cf2").map(toStr).toSet === Set(("a", "3,4,6")))
+        assert(db.iterator().map(toStr).toSet === Set(("a", "2,5")))
+      }
+    }
+  }
+
   testWithColumnFamilies("RocksDBFileManager: delete orphan files",
     TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled =>
     withTempDir { dir =>
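
The expected values in the merge-correctness test above ("1,2", "1,2,3", "2,5") reflect string-append merge semantics with ',' as the delimiter, consistent with RocksDB's StringAppendOperator. A rough functional model of a single merge step is sketched below; it is a conceptual illustration only, since the real operator runs inside RocksDB itself.

    // Conceptual model of a string-append merge with a ',' delimiter.
    def mergeOperand(existing: Option[String], operand: String): String =
      existing.map(_ + "," + operand).getOrElse(operand)

    // mergeOperand(None, "1")        == "1"
    // mergeOperand(Some("1"), "2")   == "1,2"
    // mergeOperand(Some("1,2"), "3") == "1,2,3"
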
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 99e218780e94..a8c7fc05f21e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -65,7 +65,8 @@ class FakeStateStoreProviderWithMaintenanceError extends StateStoreProvider {
       numColsPrefixKey: Int,
       useColumnFamilies: Boolean,
       storeConfs: StateStoreConf,
-      hadoopConf: Configuration): Unit = {
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean = false): Unit = {
     id = stateStoreId
   }
 
@@ -1624,6 +1625,10 @@ object StateStoreTestsHelper {
     store.put(dataToKeyRow(key1, key2), dataToValueRow(value))
   }
 
+  def merge(store: StateStore, key1: String, key2: Int, value: Int): Unit = {
+    store.merge(dataToKeyRow(key1, key2), dataToValueRow(value))
+  }
+
   def get(store: ReadStateStore, key1: String, key2: Int): Option[Int] = {
     Option(store.get(dataToKeyRow(key1, key2))).map(valueRowToData)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
index c069046eed40..be77f7a887c7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/ValueStateSuite.scala
@@ -140,7 +140,7 @@ class ValueStateSuite extends SharedSparkSession
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       testState.update(123)
       assert(testState.get() === 123)
-      testState.remove()
+      testState.clear()
       assert(!testState.exists())
       assert(testState.get() === null)
 
@@ -150,7 +150,7 @@ class ValueStateSuite extends SharedSparkSession
       testState.update(123)
       assert(testState.get() === 123)
 
-      testState.remove()
+      testState.clear()
       assert(!testState.exists())
       assert(testState.get() === null)
     }
@@ -167,13 +167,13 @@ class ValueStateSuite extends SharedSparkSession
       ImplicitGroupingKeyTracker.setImplicitKey("test_key")
       testState1.update(123)
       assert(testState1.get() === 123)
-      testState1.remove()
+      testState1.clear()
       assert(!testState1.exists())
       assert(testState1.get() === null)
 
       testState2.update(456)
       assert(testState2.get() === 456)
-      testState2.remove()
+      testState2.clear()
       assert(!testState2.exists())
       assert(testState2.get() === null)
 
@@ -189,11 +189,11 @@ class ValueStateSuite extends SharedSparkSession
       testState2.update(456)
       assert(testState2.get() === 456)
 
-      testState1.remove()
+      testState1.clear()
       assert(!testState1.exists())
       assert(testState1.get() === null)
 
-      testState2.remove()
+      testState2.clear()
       assert(!testState2.exists())
       assert(testState2.get() === null)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 05e5f3ae51a5..fa08a44dc9e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -1418,7 +1418,8 @@ class TestStateStoreProvider extends StateStoreProvider {
       numColsPrefixKey: Int,
       useColumnFamilies: Boolean,
       storeConfs: StateStoreConf,
-      hadoopConf: Configuration): Unit = {
+      hadoopConf: Configuration,
+      useMultipleValuesPerKey: Boolean = false): Unit = {
     throw new Exception("Successfully instantiated")
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
new file mode 100644
index 000000000000..f7ed813badde
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithListStateSuite.scala
@@ -0,0 +1,328 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.SparkIllegalArgumentException
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider}
+import org.apache.spark.sql.internal.SQLConf
+
+case class InputRow(key: String, action: String, value: String)
+
+class TestListStateProcessor
+  extends StatefulProcessor[String, InputRow, (String, String)] {
+
+  @transient var _processorHandle: StatefulProcessorHandle = _
+  @transient var _listState: ListState[String] = _
+
+  override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = {
+    _processorHandle = handle
+    _listState = handle.getListState("testListState")
+  }
+
+  override def handleInputRows(
+      key: String,
+      rows: Iterator[InputRow],
+      timerValues: TimerValues): Iterator[(String, String)] = {
+
+    var output = List[(String, String)]()
+
+    for (row <- rows) {
+      if (row.action == "emit") {
+        output = (key, row.value) :: output
+      } else if (row.action == "emitAllInState") {
+        _listState.get().foreach { v =>
+          output = (key, v) :: output
+        }
+        _listState.clear()
+      } else if (row.action == "append") {
+        _listState.appendValue(row.value)
+      } else if (row.action == "appendAll") {
+        _listState.appendList(row.value.split(","))
+      } else if (row.action == "put") {
+        _listState.put(row.value.split(","))
+      } else if (row.action == "remove") {
+        _listState.clear()
+      } else if (row.action == "tryAppendingNull") {
+        _listState.appendValue(null)
+      } else if (row.action == "tryAppendingNullValueInList") {
+        _listState.appendList(Array(null))
+      } else if (row.action == "tryAppendingNullList") {
+        _listState.appendList(null)
+      } else if (row.action == "tryPutNullList") {
+        _listState.put(null)
+      } else if (row.action == "tryPuttingNullInList") {
+        _listState.put(Array(null))
+      } else if (row.action == "tryPutEmptyList") {
+        _listState.put(Array())
+      } else if (row.action == "tryAppendingEmptyList") {
+        _listState.appendList(Array())
+      }
+    }
+
+    output.iterator
+  }
+
+  override def close(): Unit = {}
+}
+
+class ToggleSaveAndEmitProcessor
+  extends StatefulProcessor[String, String, String] {
+
+  @transient var _processorHandle: StatefulProcessorHandle = _
+  @transient var _listState: ListState[String] = _
+  @transient var _valueState: ValueState[Boolean] = _
+
+  override def init(handle: StatefulProcessorHandle, outputMode: OutputMode): Unit = {
+    _processorHandle = handle
+    _listState = handle.getListState("testListState")
+    _valueState = handle.getValueState("testValueState")
+  }
+
+  override def handleInputRows(
+      key: String,
+      rows: Iterator[String],
+      timerValues: TimerValues): Iterator[String] = {
+    val valueStateOption = _valueState.getOption()
+
+    if (valueStateOption.isEmpty || !valueStateOption.get) {
+      _listState.appendList(rows.toArray)
+      _valueState.update(true)
+      Seq().iterator
+    } else {
+      _valueState.clear()
+      val storedValues = _listState.get()
+      _listState.clear()
+
+      new Iterator[String] {
+        override def hasNext: Boolean = {
+          rows.hasNext || storedValues.hasNext
+        }
+
+        override def next(): String = {
+          if (rows.hasNext) {
+            rows.next()
+          } else {
+            storedValues.next()
+          }
+        }
+      }
+    }
+  }
+
+  override def close(): Unit = {}
+}
+
+class TransformWithListStateSuite extends StreamTest
+  with AlsoTestWithChangelogCheckpointingEnabled {
+  import testImplicits._
+
+  test("test appending null value in list state throw exception") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[InputRow]
+      val result = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new TestListStateProcessor(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update()) (
+        AddData(inputData, InputRow("k1", "tryAppendingNull", "")),
+        ExpectFailure[SparkIllegalArgumentException](e => {
+          assert(e.getMessage.contains("ILLEGAL_STATE_STORE_VALUE.NULL_VALUE"))
+        })
+      )
+    }
+  }
+
+  test("test putting null value in list state throw exception") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[InputRow]
+      val result = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new TestListStateProcessor(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        AddData(inputData, InputRow("k1", "tryPuttingNullInList", "")),
+        ExpectFailure[SparkIllegalArgumentException](e => {
+          assert(e.getMessage.contains("ILLEGAL_STATE_STORE_VALUE.NULL_VALUE"))
+        })
+      )
+    }
+  }
+
+  test("test putting null list in list state throw exception") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[InputRow]
+      val result = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new TestListStateProcessor(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        AddData(inputData, InputRow("k1", "tryPutNullList", "")),
+        ExpectFailure[SparkIllegalArgumentException](e => {
+          assert(e.getMessage.contains("ILLEGAL_STATE_STORE_VALUE.NULL_VALUE"))
+        })
+      )
+    }
+  }
+
+  test("test appending null list in list state throw exception") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[InputRow]
+      val result = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new TestListStateProcessor(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        AddData(inputData, InputRow("k1", "tryAppendingNullList", "")),
+        ExpectFailure[SparkIllegalArgumentException](e => {
+          assert(e.getMessage.contains("ILLEGAL_STATE_STORE_VALUE.NULL_VALUE"))
+        })
+      )
+    }
+  }
+
+  test("test putting empty list in list state throw exception") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[InputRow]
+      val result = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new TestListStateProcessor(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        AddData(inputData, InputRow("k1", "tryPutEmptyList", "")),
+        ExpectFailure[SparkIllegalArgumentException](e => {
+          assert(e.getMessage.contains("ILLEGAL_STATE_STORE_VALUE.EMPTY_LIST_VALUE"))
+        })
+      )
+    }
+  }
+
+  test("test appending empty list in list state throw exception") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[InputRow]
+      val result = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new TestListStateProcessor(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        AddData(inputData, InputRow("k1", "tryAppendingEmptyList", "")),
+        ExpectFailure[SparkIllegalArgumentException](e => {
+          assert(e.getMessage.contains("ILLEGAL_STATE_STORE_VALUE.EMPTY_LIST_VALUE"))
+        })
+      )
+    }
+  }
+
+  test("test list state correctness") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[InputRow]
+      val result = inputData.toDS()
+        .groupByKey(x => x.key)
+        .transformWithState(new TestListStateProcessor(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update()) (
+        // no interaction test
+        AddData(inputData, InputRow("k1", "emit", "v1")),
+        CheckNewAnswer(("k1", "v1")),
+        // check simple append
+        AddData(inputData, InputRow("k1", "append", "v2")),
+        AddData(inputData, InputRow("k1", "emitAllInState", "")),
+        CheckNewAnswer(("k1", "v2")),
+        // multiple appends are correctly stored and emitted
+        AddData(inputData, InputRow("k2", "append", "v1")),
+        AddData(inputData, InputRow("k1", "append", "v4")),
+        AddData(inputData, InputRow("k2", "append", "v2")),
+        AddData(inputData, InputRow("k1", "emit", "v5")),
+        AddData(inputData, InputRow("k2", "emit", "v3")),
+        CheckNewAnswer(("k1", "v5"), ("k2", "v3")),
+        AddData(inputData, InputRow("k1", "emitAllInState", "")),
+        AddData(inputData, InputRow("k2", "emitAllInState", "")),
+        CheckNewAnswer(("k2", "v1"), ("k2", "v2"), ("k1", "v4")),
+        // check appendAll with append
+        AddData(inputData, InputRow("k3", "appendAll", "v1,v2,v3")),
+        AddData(inputData, InputRow("k3", "emit", "v4")),
+        AddData(inputData, InputRow("k3", "append", "v5")),
+        CheckNewAnswer(("k3", "v4")),
+        AddData(inputData, InputRow("k3", "emitAllInState", "")),
+        CheckNewAnswer(("k3", "v1"), ("k3", "v2"), ("k3", "v3"), ("k3", "v5")),
+        // check removal cleans up all data in state
+        AddData(inputData, InputRow("k4", "append", "v2")),
+        AddData(inputData, InputRow("k4", "appendList", "v3,v4")),
+        AddData(inputData, InputRow("k4", "remove", "")),
+        AddData(inputData, InputRow("k4", "emitAllInState", "")),
+        CheckNewAnswer(),
+        // check put cleans up previous state and adds new state
+        AddData(inputData, InputRow("k5", "appendAll", "v1,v2,v3")),
+        AddData(inputData, InputRow("k5", "append", "v4")),
+        AddData(inputData, InputRow("k5", "put", "v5,v6")),
+        AddData(inputData, InputRow("k5", "emitAllInState", "")),
+        CheckNewAnswer(("k5", "v5"), ("k5", "v6"))
+      )
+    }
+  }
+
+  test("test ValueState And ListState in Processor") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName) {
+
+      val inputData = MemoryStream[String]
+      val result = inputData.toDS()
+        .groupByKey(x => x)
+        .transformWithState(new ToggleSaveAndEmitProcessor(),
+          TimeoutMode.NoTimeouts(),
+          OutputMode.Update())
+
+      testStream(result, OutputMode.Update())(
+        AddData(inputData, "k1"),
+        AddData(inputData, "k2"),
+        CheckNewAnswer(),
+        AddData(inputData, "k1"),
+        AddData(inputData, "k2"),
+        CheckNewAnswer("k1", "k1", "k2", "k2")
+      )
+    }
+  }
+}
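
The ListState operations exercised by the two processors above map onto a small surface. The trait below is a sketch inferred from that usage only; ListState.scala added in this patch is the authoritative definition.

    // Sketch of the ListState[S] surface as used by TestListStateProcessor
    // and ToggleSaveAndEmitProcessor above (inferred from usage, not copied).
    trait ListStateSketch[S] {
      def get(): Iterator[S]                    // all values stored for the grouping key
      def put(newState: Array[S]): Unit         // replace the whole list
      def appendValue(newState: S): Unit        // append a single value
      def appendList(newState: Array[S]): Unit  // append multiple values
      def clear(): Unit                         // remove all values for the key
    }
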
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 7b448ac93419..a4a04e0b5077 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -46,7 +46,7 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S
       timerValues: TimerValues): Iterator[(String, String)] = {
     val count = _countState.getOption().getOrElse(0L) + 1
     if (count == 3) {
-      _countState.remove()
+      _countState.clear()
       Iterator.empty
     } else {
       _countState.update(count)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org