You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by dw...@apache.org on 2021/02/24 16:21:01 UTC

[flink] 06/09: [FLINK-21344] Handle heap timers in Rocks state

This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit a9fef44654b0c154af573f5c27398e27d3351cf9
Author: Dawid Wysakowicz <dw...@apache.org>
AuthorDate: Mon Feb 8 17:09:19 2021 +0100

    [FLINK-21344] Handle heap timers in Rocks state
    
    We serialize the heap timers into the same format they would have if
    they were actually stored in RocksDB, instead of storing them in raw
    operator state. This lets users switch between heap and RocksDB timers.
---
 .../runtime/state/HeapPriorityQueuesManager.java   | 110 +++++++++
 .../runtime/state/heap/HeapKeyedStateBackend.java  |  73 +-----
 .../state/heap/HeapMetaInfoRestoreOperation.java   |   5 +-
 .../HeapPriorityQueueSnapshotRestoreWrapper.java   |   5 +-
 .../state/heap/HeapPriorityQueueStateSnapshot.java |   5 +
 .../state/heap/HeapSavepointRestoreOperation.java  |   6 +-
 .../streaming/state/RocksDBKeyedStateBackend.java  |  26 ++-
 .../state/RocksDBKeyedStateBackendBuilder.java     |  37 ++-
 .../state/iterator/RocksQueueIterator.java         | 141 ++++++++++++
 .../RocksStatesPerKeyGroupMergeIterator.java       |  23 +-
 .../state/restore/RocksDBFullRestoreOperation.java |  30 +--
 .../RocksDBHeapTimersFullRestoreOperation.java     | 255 +++++++++++++++++++++
 .../snapshot/RocksDBFullSnapshotResources.java     |  26 ++-
 .../state/snapshot/RocksFullSnapshotStrategy.java  |  17 ++
 ...RocksKeyGroupsRocksSingleStateIteratorTest.java |   6 +-
 .../flink/test/state/BackendSwitchSpecs.java       |  16 +-
 .../RocksSavepointStateBackendSwitchTest.java      |  22 +-
 17 files changed, 696 insertions(+), 107 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/HeapPriorityQueuesManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/HeapPriorityQueuesManager.java
new file mode 100644
index 0000000..27d500d
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/HeapPriorityQueuesManager.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSet;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSnapshotRestoreWrapper;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.StateMigrationException;
+
+import javax.annotation.Nonnull;
+
+import java.util.Map;
+
+/** Creates heap priority queues and registers them, with their meta info, in the supplied map. */
+@Internal
+public class HeapPriorityQueuesManager {
+
+    private final Map<String, HeapPriorityQueueSnapshotRestoreWrapper<?>> registeredPQStates;
+    private final HeapPriorityQueueSetFactory priorityQueueSetFactory;
+    private final KeyGroupRange keyGroupRange;
+    private final int numberOfKeyGroups;
+
+    public HeapPriorityQueuesManager(
+            Map<String, HeapPriorityQueueSnapshotRestoreWrapper<?>> registeredPQStates,
+            HeapPriorityQueueSetFactory priorityQueueSetFactory,
+            KeyGroupRange keyGroupRange,
+            int numberOfKeyGroups) {
+        this.registeredPQStates = registeredPQStates; // kept by reference, not copied
+        this.priorityQueueSetFactory = priorityQueueSetFactory;
+        this.keyGroupRange = keyGroupRange;
+        this.numberOfKeyGroups = numberOfKeyGroups;
+    }
+
+    @SuppressWarnings("unchecked") // for the cast of the registered wrapper below
+    @Nonnull
+    public <T extends HeapPriorityQueueElement & PriorityComparable<? super T> & Keyed<?>>
+            KeyGroupedInternalPriorityQueue<T> createOrUpdate(
+                    @Nonnull String stateName,
+                    @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
+
+        final HeapPriorityQueueSnapshotRestoreWrapper<T> existingState =
+                (HeapPriorityQueueSnapshotRestoreWrapper<T>) registeredPQStates.get(stateName);
+
+        if (existingState != null) { // stateName is already registered
+            TypeSerializerSchemaCompatibility<T> compatibilityResult =
+                    existingState
+                            .getMetaInfo()
+                            .updateElementSerializer(byteOrderedElementSerializer);
+
+            if (compatibilityResult.isIncompatible()) { // incompatible serializers are rejected
+                throw new FlinkRuntimeException(
+                        new StateMigrationException(
+                                "For heap backends, the new priority queue serializer must not be incompatible."));
+            } else {
+                registeredPQStates.put( // swap in a wrapper built for the new serializer
+                        stateName,
+                        existingState.forUpdatedSerializer(byteOrderedElementSerializer));
+            }
+
+            return existingState.getPriorityQueue(); // queue of the previously registered wrapper
+        } else { // first registration under stateName
+            final RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo =
+                    new RegisteredPriorityQueueStateBackendMetaInfo<>(
+                            stateName, byteOrderedElementSerializer);
+            return createInternal(metaInfo);
+        }
+    }
+
+    @Nonnull
+    private <T extends HeapPriorityQueueElement & PriorityComparable<? super T> & Keyed<?>>
+            KeyGroupedInternalPriorityQueue<T> createInternal(
+                    RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo) {
+
+        final String stateName = metaInfo.getName();
+        final HeapPriorityQueueSet<T> priorityQueue =
+                priorityQueueSetFactory.create(stateName, metaInfo.getElementSerializer());
+
+        HeapPriorityQueueSnapshotRestoreWrapper<T> wrapper =
+                new HeapPriorityQueueSnapshotRestoreWrapper<>(
+                        priorityQueue,
+                        metaInfo,
+                        KeyExtractorFunction.forKeyedObjects(),
+                        keyGroupRange,
+                        numberOfKeyGroups);
+
+        registeredPQStates.put(stateName, wrapper); // overwrites any entry for stateName
+        return priorityQueue;
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 8e6c356..c728fea 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -35,7 +35,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.KeyExtractorFunction;
+import org.apache.flink.runtime.state.HeapPriorityQueuesManager;
 import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
 import org.apache.flink.runtime.state.Keyed;
 import org.apache.flink.runtime.state.KeyedStateFunction;
@@ -43,7 +43,6 @@ import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.LocalRecoveryConfig;
 import org.apache.flink.runtime.state.PriorityComparable;
 import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
-import org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.SnapshotStrategyRunner;
 import org.apache.flink.runtime.state.StateSnapshotRestore;
@@ -96,9 +95,6 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
     /** Map of registered Key/Value states. */
     private final Map<String, StateTable<K, ?, ?>> registeredKVStates;
 
-    /** Map of registered priority queue set states. */
-    private final Map<String, HeapPriorityQueueSnapshotRestoreWrapper<?>> registeredPQStates;
-
     /** The configuration for local recovery. */
     private final LocalRecoveryConfig localRecoveryConfig;
 
@@ -113,7 +109,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
     private final StateTableFactory<K> stateTableFactory;
 
     /** Factory for state that is organized as priority queue. */
-    private final HeapPriorityQueueSetFactory priorityQueueSetFactory;
+    private final HeapPriorityQueuesManager priorityQueuesManager;
 
     public HeapKeyedStateBackend(
             TaskKvStateRegistry kvStateRegistry,
@@ -141,12 +137,16 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
                 keyGroupCompressionDecorator,
                 keyContext);
         this.registeredKVStates = registeredKVStates;
-        this.registeredPQStates = registeredPQStates;
         this.localRecoveryConfig = localRecoveryConfig;
-        this.priorityQueueSetFactory = priorityQueueSetFactory;
         this.checkpointStrategyRunner = checkpointStrategyRunner;
         this.savepointStrategyRunner = savepointStrategyRunner;
         this.stateTableFactory = stateTableFactory;
+        this.priorityQueuesManager =
+                new HeapPriorityQueuesManager(
+                        registeredPQStates,
+                        priorityQueueSetFactory,
+                        keyContext.getKeyGroupRange(),
+                        keyContext.getNumberOfKeyGroups());
         LOG.info("Initializing heap keyed state backend with stream factory.");
     }
 
@@ -154,67 +154,13 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
     //  state backend operations
     // ------------------------------------------------------------------------
 
-    @SuppressWarnings("unchecked")
     @Nonnull
     @Override
     public <T extends HeapPriorityQueueElement & PriorityComparable<? super T> & Keyed<?>>
             KeyGroupedInternalPriorityQueue<T> create(
                     @Nonnull String stateName,
                     @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
-
-        final HeapPriorityQueueSnapshotRestoreWrapper<T> existingState =
-                (HeapPriorityQueueSnapshotRestoreWrapper<T>) registeredPQStates.get(stateName);
-
-        if (existingState != null) {
-            // TODO we implement the simple way of supporting the current functionality, mimicking
-            // keyed state
-            // because this should be reworked in FLINK-9376 and then we should have a common
-            // algorithm over
-            // StateMetaInfoSnapshot that avoids this code duplication.
-
-            TypeSerializerSchemaCompatibility<T> compatibilityResult =
-                    existingState
-                            .getMetaInfo()
-                            .updateElementSerializer(byteOrderedElementSerializer);
-
-            if (compatibilityResult.isIncompatible()) {
-                throw new FlinkRuntimeException(
-                        new StateMigrationException(
-                                "For heap backends, the new priority queue serializer must not be incompatible."));
-            } else {
-                registeredPQStates.put(
-                        stateName,
-                        existingState.forUpdatedSerializer(byteOrderedElementSerializer));
-            }
-
-            return existingState.getPriorityQueue();
-        } else {
-            final RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo =
-                    new RegisteredPriorityQueueStateBackendMetaInfo<>(
-                            stateName, byteOrderedElementSerializer);
-            return createInternal(metaInfo);
-        }
-    }
-
-    @Nonnull
-    private <T extends HeapPriorityQueueElement & PriorityComparable<? super T> & Keyed<?>>
-            KeyGroupedInternalPriorityQueue<T> createInternal(
-                    RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo) {
-
-        final String stateName = metaInfo.getName();
-        final HeapPriorityQueueSet<T> priorityQueue =
-                priorityQueueSetFactory.create(stateName, metaInfo.getElementSerializer());
-
-        HeapPriorityQueueSnapshotRestoreWrapper<T> wrapper =
-                new HeapPriorityQueueSnapshotRestoreWrapper<>(
-                        priorityQueue,
-                        metaInfo,
-                        KeyExtractorFunction.forKeyedObjects(),
-                        keyGroupRange,
-                        numberOfKeyGroups);
-
-        registeredPQStates.put(stateName, wrapper);
-        return priorityQueue;
+        return priorityQueuesManager.createOrUpdate(stateName, byteOrderedElementSerializer);
     }
 
     private <N, V> StateTable<K, N, V> tryRegisterStateTable(
@@ -409,7 +355,6 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
     /** Returns the total number of state entries across all keys/namespaces. */
     @VisibleForTesting
-    @SuppressWarnings("unchecked")
     @Override
     public int numKeyValueStateEntries() {
         int sum = 0;
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMetaInfoRestoreOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMetaInfoRestoreOperation.java
index aecc44b..8badfd2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMetaInfoRestoreOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapMetaInfoRestoreOperation.java
@@ -115,9 +115,10 @@ class HeapMetaInfoRestoreOperation<K> {
         return kvStatesById;
     }
 
-    private <T extends HeapPriorityQueueElement & PriorityComparable & Keyed>
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    private <T extends HeapPriorityQueueElement & PriorityComparable<? super T> & Keyed<?>>
             HeapPriorityQueueSnapshotRestoreWrapper<T> createInternal(
-                    RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo) {
+                    RegisteredPriorityQueueStateBackendMetaInfo metaInfo) {
 
         final String stateName = metaInfo.getName();
         final HeapPriorityQueueSet<T> priorityQueue =
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java
index 8b44c72..8564c15 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java
@@ -23,7 +23,6 @@ import org.apache.flink.runtime.state.KeyExtractorFunction;
 import org.apache.flink.runtime.state.KeyGroupPartitioner;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInfo;
-import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
 import org.apache.flink.runtime.state.StateSnapshotRestore;
 
@@ -61,10 +60,10 @@ public class HeapPriorityQueueSnapshotRestoreWrapper<T extends HeapPriorityQueue
     @SuppressWarnings("unchecked")
     @Nonnull
     @Override
-    public StateSnapshot stateSnapshot() {
+    public HeapPriorityQueueStateSnapshot<T> stateSnapshot() {
         final T[] queueDump =
                 (T[]) priorityQueue.toArray(new HeapPriorityQueueElement[priorityQueue.size()]);
-        return new HeapPriorityQueueStateSnapshot<>(
+        return new HeapPriorityQueueStateSnapshot<T>(
                 queueDump,
                 keyExtractorFunction,
                 metaInfo.deepCopy(),
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueStateSnapshot.java
index e183085..fb597cf 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueStateSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueStateSnapshot.java
@@ -117,6 +117,11 @@ public class HeapPriorityQueueStateSnapshot<T> implements StateSnapshot {
         return metaInfo.snapshot();
     }
 
+    @Nonnull
+    public RegisteredPriorityQueueStateBackendMetaInfo<T> getMetaInfo() {
+        return metaInfo; // the snapshot's own field, not a copy
+    }
+
     @Override
     public void release() {}
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSavepointRestoreOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSavepointRestoreOperation.java
index 13fdf8f..d52b709 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSavepointRestoreOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSavepointRestoreOperation.java
@@ -164,8 +164,8 @@ public class HeapSavepointRestoreOperation<K> implements RestoreOperation<Void>
     @SuppressWarnings("unchecked")
     private void readPriorityQueue(KeyGroupEntry groupEntry, StateMetaInfoSnapshot infoSnapshot)
             throws IOException {
-        DataInputDeserializer keyDeserializer = new DataInputDeserializer(groupEntry.getKey());
-        keyDeserializer.skipBytesToRead(keyGroupPrefixBytes);
+        entryKeyDeserializer.setBuffer(groupEntry.getKey());
+        entryKeyDeserializer.skipBytesToRead(keyGroupPrefixBytes);
         HeapPriorityQueueSnapshotRestoreWrapper<HeapPriorityQueueElement>
                 priorityQueueSnapshotRestoreWrapper =
                         (HeapPriorityQueueSnapshotRestoreWrapper<HeapPriorityQueueElement>)
@@ -174,7 +174,7 @@ public class HeapSavepointRestoreOperation<K> implements RestoreOperation<Void>
                 priorityQueueSnapshotRestoreWrapper
                         .getMetaInfo()
                         .getElementSerializer()
-                        .deserialize(keyDeserializer);
+                        .deserialize(entryKeyDeserializer);
         HeapPriorityQueueSet<HeapPriorityQueueElement> priorityQueue =
                 priorityQueueSnapshotRestoreWrapper.getPriorityQueue();
         priorityQueue.add(timer);
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 8aab8b8..0f53955 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -44,6 +44,7 @@ import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.CompositeKeySerializationUtils;
+import org.apache.flink.runtime.state.HeapPriorityQueuesManager;
 import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
 import org.apache.flink.runtime.state.Keyed;
 import org.apache.flink.runtime.state.KeyedStateHandle;
@@ -58,6 +59,7 @@ import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTran
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSnapshotRestoreWrapper;
 import org.apache.flink.runtime.state.heap.InternalKeyContext;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.util.FileUtils;
@@ -96,7 +98,6 @@ import java.util.stream.StreamSupport;
 
 import static org.apache.flink.contrib.streaming.state.RocksDBSnapshotTransformFactoryAdaptor.wrapStateSnapshotTransformFactory;
 import static org.apache.flink.runtime.state.SnapshotStrategyRunner.ExecutionType.ASYNCHRONOUS;
-import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /**
@@ -182,6 +183,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
      */
     private final LinkedHashMap<String, RocksDbKvStateInfo> kvStateInformation;
 
+    private final HeapPriorityQueuesManager heapPriorityQueuesManager;
+
     /** Number of bytes required to prefix the key groups. */
     private final int keyGroupPrefixBytes;
 
@@ -240,6 +243,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
             TtlTimeProvider ttlTimeProvider,
             RocksDB db,
             LinkedHashMap<String, RocksDbKvStateInfo> kvStateInformation,
+            Map<String, HeapPriorityQueueSnapshotRestoreWrapper<?>> registeredPQStates,
             int keyGroupPrefixBytes,
             CloseableRegistry cancelStreamRegistry,
             StreamCompressionDecorator keyGroupCompressionDecorator,
@@ -279,7 +283,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
         this.writeOptions = optionsContainer.getWriteOptions();
         this.readOptions = optionsContainer.getReadOptions();
-        checkArgument(writeBatchSize >= 0, "Write batch size have to be no negative value.");
         this.writeBatchSize = writeBatchSize;
         this.db = db;
         this.rocksDBResourceGuard = rocksDBResourceGuard;
@@ -290,6 +293,16 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
         this.nativeMetricMonitor = nativeMetricMonitor;
         this.sharedRocksKeyBuilder = sharedRocksKeyBuilder;
         this.priorityQueueFactory = priorityQueueFactory;
+        if (priorityQueueFactory instanceof HeapPriorityQueueSetFactory) {
+            this.heapPriorityQueuesManager =
+                    new HeapPriorityQueuesManager(
+                            registeredPQStates,
+                            (HeapPriorityQueueSetFactory) priorityQueueFactory,
+                            keyContext.getKeyGroupRange(),
+                            keyContext.getNumberOfKeyGroups());
+        } else {
+            this.heapPriorityQueuesManager = null;
+        }
     }
 
     @SuppressWarnings("unchecked")
@@ -459,7 +472,12 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
             KeyGroupedInternalPriorityQueue<T> create(
                     @Nonnull String stateName,
                     @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
-        return priorityQueueFactory.create(stateName, byteOrderedElementSerializer);
+        if (this.heapPriorityQueuesManager != null) {
+            return this.heapPriorityQueuesManager.createOrUpdate(
+                    stateName, byteOrderedElementSerializer);
+        } else {
+            return priorityQueueFactory.create(stateName, byteOrderedElementSerializer);
+        }
     }
 
     private void cleanInstanceBasePath() {
@@ -640,7 +658,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
                     TypeSerializer<SV> stateSerializer)
                     throws Exception {
 
-        @SuppressWarnings("unchecked")
         RegisteredKeyValueStateBackendMetaInfo<N, SV> restoredKvStateMetaInfo = oldStateInfo.f1;
 
         // fetch current serializer now because if it is incompatible, we can't access
@@ -814,7 +831,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
     }
 
     @VisibleForTesting
-    @SuppressWarnings("unchecked")
     @Override
     public int numKeyValueStateEntries() {
         int count = 0;
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackendBuilder.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackendBuilder.java
index ce90d05..382a185 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackendBuilder.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackendBuilder.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.contrib.streaming.state.restore.RocksDBFullRestoreOperation;
+import org.apache.flink.contrib.streaming.state.restore.RocksDBHeapTimersFullRestoreOperation;
 import org.apache.flink.contrib.streaming.state.restore.RocksDBIncrementalRestoreOperation;
 import org.apache.flink.contrib.streaming.state.restore.RocksDBNoneRestoreOperation;
 import org.apache.flink.contrib.streaming.state.restore.RocksDBRestoreOperation;
@@ -45,6 +46,7 @@ import org.apache.flink.runtime.state.SerializedCompositeKeyBuilder;
 import org.apache.flink.runtime.state.StateHandleID;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSnapshotRestoreWrapper;
 import org.apache.flink.runtime.state.heap.InternalKeyContext;
 import org.apache.flink.runtime.state.heap.InternalKeyContextImpl;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
@@ -249,6 +251,8 @@ public class RocksDBKeyedStateBackendBuilder<K> extends AbstractKeyedStateBacken
         CloseableRegistry cancelStreamRegistryForBackend = new CloseableRegistry();
         LinkedHashMap<String, RocksDBKeyedStateBackend.RocksDbKvStateInfo> kvStateInformation =
                 new LinkedHashMap<>();
+        LinkedHashMap<String, HeapPriorityQueueSnapshotRestoreWrapper<?>> registeredPQStates =
+                new LinkedHashMap<>();
         RocksDB db = null;
         RocksDBRestoreOperation restoreOperation = null;
         RocksDbTtlCompactFiltersManager ttlCompactFiltersManager =
@@ -262,6 +266,7 @@ public class RocksDBKeyedStateBackendBuilder<K> extends AbstractKeyedStateBacken
         int keyGroupPrefixBytes =
                 CompositeKeySerializationUtils.computeRequiredBytesInKeyGroupPrefix(
                         numberOfKeyGroups);
+
         try {
             // Variables for snapshot strategy when incremental checkpoint is enabled
             UUID backendUID = UUID.randomUUID();
@@ -282,6 +287,7 @@ public class RocksDBKeyedStateBackendBuilder<K> extends AbstractKeyedStateBacken
                                 keyGroupPrefixBytes,
                                 cancelStreamRegistry,
                                 kvStateInformation,
+                                registeredPQStates,
                                 ttlCompactFiltersManager);
                 RocksDBRestoreResult restoreResult = restoreOperation.restore();
                 db = restoreResult.getDb();
@@ -297,6 +303,7 @@ public class RocksDBKeyedStateBackendBuilder<K> extends AbstractKeyedStateBacken
             writeBatchWrapper =
                     new RocksDBWriteBatchWrapper(
                             db, optionsContainer.getWriteOptions(), writeBatchSize);
+
             // it is important that we only create the key builder after the restore, and not
             // before;
             // restore operations may reconfigure the key serializer, so accessing the key
@@ -313,6 +320,7 @@ public class RocksDBKeyedStateBackendBuilder<K> extends AbstractKeyedStateBacken
                             cancelStreamRegistryForBackend,
                             rocksDBResourceGuard,
                             kvStateInformation,
+                            registeredPQStates,
                             keyGroupPrefixBytes,
                             db,
                             backendUID,
@@ -377,6 +385,7 @@ public class RocksDBKeyedStateBackendBuilder<K> extends AbstractKeyedStateBacken
                 this.ttlTimeProvider,
                 db,
                 kvStateInformation,
+                registeredPQStates,
                 keyGroupPrefixBytes,
                 cancelStreamRegistryForBackend,
                 this.keyGroupCompressionDecorator,
@@ -397,6 +406,7 @@ public class RocksDBKeyedStateBackendBuilder<K> extends AbstractKeyedStateBacken
             int keyGroupPrefixBytes,
             CloseableRegistry cancelStreamRegistry,
             LinkedHashMap<String, RocksDBKeyedStateBackend.RocksDbKvStateInfo> kvStateInformation,
+            LinkedHashMap<String, HeapPriorityQueueSnapshotRestoreWrapper<?>> registeredPQStates,
             RocksDbTtlCompactFiltersManager ttlCompactFiltersManager) {
         DBOptions dbOptions = optionsContainer.getDbOptions();
         if (restoreStateHandles.isEmpty()) {
@@ -431,6 +441,24 @@ public class RocksDBKeyedStateBackendBuilder<K> extends AbstractKeyedStateBacken
                     ttlCompactFiltersManager,
                     writeBatchSize,
                     optionsContainer.getWriteBufferManagerCapacity());
+        } else if (priorityQueueStateType == RocksDBStateBackend.PriorityQueueStateType.HEAP) {
+            return new RocksDBHeapTimersFullRestoreOperation<>(
+                    keyGroupRange,
+                    numberOfKeyGroups,
+                    userCodeClassLoader,
+                    kvStateInformation,
+                    registeredPQStates,
+                    createHeapQueueFactory(),
+                    keySerializerProvider,
+                    instanceRocksDBPath,
+                    dbOptions,
+                    columnFamilyOptionsFactory,
+                    nativeMetricOptions,
+                    metricGroup,
+                    restoreStateHandles,
+                    ttlCompactFiltersManager,
+                    writeBatchSize,
+                    optionsContainer.getWriteBufferManagerCapacity());
         } else {
             return new RocksDBFullRestoreOperation<>(
                     keyGroupRange,
@@ -453,6 +481,7 @@ public class RocksDBKeyedStateBackendBuilder<K> extends AbstractKeyedStateBacken
             CloseableRegistry cancelStreamRegistry,
             ResourceGuard rocksDBResourceGuard,
             LinkedHashMap<String, RocksDBKeyedStateBackend.RocksDbKvStateInfo> kvStateInformation,
+            LinkedHashMap<String, HeapPriorityQueueSnapshotRestoreWrapper<?>> registeredPQStates,
             int keyGroupPrefixBytes,
             RocksDB db,
             UUID backendUID,
@@ -464,6 +493,7 @@ public class RocksDBKeyedStateBackendBuilder<K> extends AbstractKeyedStateBacken
                         rocksDBResourceGuard,
                         keySerializerProvider.currentSchemaSerializer(),
                         kvStateInformation,
+                        registeredPQStates,
                         keyGroupRange,
                         keyGroupPrefixBytes,
                         localRecoveryConfig,
@@ -502,8 +532,7 @@ public class RocksDBKeyedStateBackendBuilder<K> extends AbstractKeyedStateBacken
         PriorityQueueSetFactory priorityQueueFactory;
         switch (priorityQueueStateType) {
             case HEAP:
-                priorityQueueFactory =
-                        new HeapPriorityQueueSetFactory(keyGroupRange, numberOfKeyGroups, 128);
+                priorityQueueFactory = createHeapQueueFactory();
                 break;
             case ROCKSDB:
                 priorityQueueFactory =
@@ -526,6 +555,10 @@ public class RocksDBKeyedStateBackendBuilder<K> extends AbstractKeyedStateBacken
         return priorityQueueFactory;
     }
 
+    /**
+     * Creates the factory for heap priority queues over this backend's key groups. Used both for
+     * the HEAP priority queue state type and for restoring heap timers from a full snapshot, so
+     * both paths create their queues with the same settings.
+     */
+    private HeapPriorityQueueSetFactory createHeapQueueFactory() {
+        // NOTE(review): 128 is presumably the minimum initial capacity per queue; confirm against
+        // HeapPriorityQueueSetFactory.
+        return new HeapPriorityQueueSetFactory(keyGroupRange, numberOfKeyGroups, 128);
+    }
+
     private void prepareDirectories() throws IOException {
         checkAndCreateDirectory(instanceBasePath);
         if (instanceRocksDBPath.exists()) {
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/iterator/RocksQueueIterator.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/iterator/RocksQueueIterator.java
new file mode 100644
index 0000000..b4948cc
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/iterator/RocksQueueIterator.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state.iterator;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.runtime.state.CompositeKeySerializationUtils;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueStateSnapshot;
+import org.apache.flink.util.FlinkRuntimeException;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+/**
+ * An iterator over heap timers that produces rocks compatible binary format.
+ *
+ * <p>Every element of the wrapped {@link HeapPriorityQueueStateSnapshot} is emitted as one entry.
+ * The entry's key is the key-group prefix (written by {@link
+ * CompositeKeySerializationUtils#writeKeyGroup}) followed by the element serialized with the
+ * queue's element serializer. The entry's value is always an empty byte array. Key groups are
+ * visited in the order returned by {@link KeyGroupRange#iterator()}, and key groups without
+ * elements are skipped.
+ */
+public final class RocksQueueIterator implements SingleStateIterator {
+
+    /** Value returned for every entry; the element itself is serialized into the key. */
+    private static final byte[] EMPTY_BYTE_ARRAY = new byte[0];
+
+    /** Reused buffer holding the current key-group prefix followed by the current element. */
+    private final DataOutputSerializer keyOut = new DataOutputSerializer(128);
+    private final HeapPriorityQueueStateSnapshot<?> queueSnapshot;
+    private final Iterator<Integer> keyGroupRangeIterator;
+    private final int kvStateId;
+    private final int keyGroupPrefixBytes;
+    private final TypeSerializer<Object> elementSerializer;
+
+    /** Elements of the key group currently being iterated; null until a key group is entered. */
+    private Iterator<Object> elementsForKeyGroup;
+    /** Offset in {@link #keyOut} right after the prefix of the current key group. */
+    private int afterKeyMark = 0;
+
+    private boolean isValid;
+    private byte[] currentKey;
+
+    /**
+     * Creates the iterator and positions it on the first element, if there is one.
+     *
+     * @param queuesSnapshot snapshot of the heap priority queue whose elements are emitted
+     * @param keyGroupRange key groups to visit
+     * @param keyGroupPrefixBytes number of bytes used for the key-group prefix of each key
+     * @param kvStateId state id reported by {@link #getKvStateId()}
+     * @throws FlinkRuntimeException if writing the first key fails with an {@link IOException}
+     */
+    public RocksQueueIterator(
+            HeapPriorityQueueStateSnapshot<?> queuesSnapshot,
+            KeyGroupRange keyGroupRange,
+            int keyGroupPrefixBytes,
+            int kvStateId) {
+        this.queueSnapshot = queuesSnapshot;
+        this.elementSerializer = castToType(queuesSnapshot.getMetaInfo().getElementSerializer());
+        this.keyGroupRangeIterator = keyGroupRange.iterator();
+        this.keyGroupPrefixBytes = keyGroupPrefixBytes;
+        this.kvStateId = kvStateId;
+        // For an empty key-group range isValid keeps its default (false) and
+        // elementsForKeyGroup stays null.
+        if (keyGroupRangeIterator.hasNext()) {
+            try {
+                if (moveToNextNonEmptyKeyGroup()) {
+                    isValid = true;
+                    // Serialize the first element of the found key group into currentKey.
+                    next();
+                } else {
+                    isValid = false;
+                }
+            } catch (IOException e) {
+                throw new FlinkRuntimeException(e);
+            }
+        }
+    }
+
+    /**
+     * Moves to the next element. When the current key group is exhausted, it switches to the next
+     * non-empty key group. When no elements remain, the iterator is marked invalid.
+     *
+     * <p>NOTE(review): {@code elementsForKeyGroup} is still null when the key-group range was
+     * empty, so calling this method on such an iterator throws a NullPointerException. Callers
+     * presumably check {@link #isValid()} first.
+     *
+     * @throws FlinkRuntimeException if writing the key fails with an {@link IOException}
+     */
+    @Override
+    public void next() {
+        try {
+            if (!elementsForKeyGroup.hasNext()) {
+                boolean hasElement = moveToNextNonEmptyKeyGroup();
+                if (!hasElement) {
+                    isValid = false;
+                    return;
+                }
+            }
+            // Keep the key-group prefix and overwrite the previous element's bytes after it.
+            keyOut.setPosition(afterKeyMark);
+            elementSerializer.serialize(elementsForKeyGroup.next(), keyOut);
+            this.currentKey = keyOut.getCopyOfBuffer();
+        } catch (IOException e) {
+            throw new FlinkRuntimeException(e);
+        }
+    }
+
+    /**
+     * Advances to the next key group of the range that has at least one element, and writes that
+     * key group's prefix into {@link #keyOut}.
+     *
+     * @return true if such a key group was found, false if the range is exhausted
+     */
+    private boolean moveToNextNonEmptyKeyGroup() throws IOException {
+        while (keyGroupRangeIterator.hasNext()) {
+            Integer keyGroupId = keyGroupRangeIterator.next();
+            elementsForKeyGroup = castToType(queueSnapshot.getIteratorForKeyGroup(keyGroupId));
+            if (elementsForKeyGroup.hasNext()) {
+                writeKeyGroupId(keyGroupId);
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Clears {@link #keyOut}, writes the prefix of the given key group, and records the offset at
+     * which element bytes start.
+     */
+    private void writeKeyGroupId(Integer keyGroupId) throws IOException {
+        keyOut.clear();
+        CompositeKeySerializationUtils.writeKeyGroup(keyGroupId, keyGroupPrefixBytes, keyOut);
+        afterKeyMark = keyOut.length();
+    }
+
+    /** Unchecked cast of the wildcard-typed snapshot serializer to the type used here. */
+    @SuppressWarnings("unchecked")
+    private static <T> TypeSerializer<T> castToType(TypeSerializer<?> typeSerializer) {
+        return (TypeSerializer<T>) typeSerializer;
+    }
+
+    /** Unchecked cast of the wildcard-typed per-key-group iterator to the type used here. */
+    @SuppressWarnings("unchecked")
+    private static <T> Iterator<T> castToType(Iterator<?> iterator) {
+        return (Iterator<T>) iterator;
+    }
+
+    @Override
+    public boolean isValid() {
+        return isValid;
+    }
+
+    /** Returns the current key: key-group prefix followed by the serialized element. */
+    @Override
+    public byte[] key() {
+        return currentKey;
+    }
+
+    /** Always returns an empty array; queue elements have no separate value. */
+    @Override
+    public byte[] value() {
+        return EMPTY_BYTE_ARRAY;
+    }
+
+    @Override
+    public int getKvStateId() {
+        return kvStateId;
+    }
+
+    /** Nothing to release: this iterator only reads from the in-memory queue snapshot. */
+    @Override
+    public void close() {}
+}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/iterator/RocksStatesPerKeyGroupMergeIterator.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/iterator/RocksStatesPerKeyGroupMergeIterator.java
index 613d181..ed8cc0d 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/iterator/RocksStatesPerKeyGroupMergeIterator.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/iterator/RocksStatesPerKeyGroupMergeIterator.java
@@ -73,6 +73,7 @@ public class RocksStatesPerKeyGroupMergeIterator implements KeyValueStateIterato
     public RocksStatesPerKeyGroupMergeIterator(
             final CloseableRegistry closeableRegistry,
             List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators,
+            List<SingleStateIterator> heapPriorityQueueIterators,
             final int keyGroupPrefixByteCount)
             throws IOException {
         Preconditions.checkNotNull(closeableRegistry);
@@ -82,8 +83,8 @@ public class RocksStatesPerKeyGroupMergeIterator implements KeyValueStateIterato
         this.closeableRegistry = closeableRegistry;
         this.keyGroupPrefixByteCount = keyGroupPrefixByteCount;
 
-        if (kvStateIterators.size() > 0) {
-            this.heap = buildIteratorHeap(kvStateIterators);
+        if (kvStateIterators.size() > 0 || heapPriorityQueueIterators.size() > 0) {
+            this.heap = buildIteratorHeap(kvStateIterators, heapPriorityQueueIterators);
             this.valid = !heap.isEmpty();
             this.currentSubIterator = heap.poll();
             kvStateIterators.clear();
@@ -129,13 +130,17 @@ public class RocksStatesPerKeyGroupMergeIterator implements KeyValueStateIterato
     }
 
     private PriorityQueue<SingleStateIterator> buildIteratorHeap(
-            List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators) throws IOException {
+            List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators,
+            List<SingleStateIterator> heapPriorityQueueIterators)
+            throws IOException {
 
         Comparator<SingleStateIterator> iteratorComparator =
                 COMPARATORS.get(keyGroupPrefixByteCount - 1);
 
         PriorityQueue<SingleStateIterator> iteratorPriorityQueue =
-                new PriorityQueue<>(kvStateIterators.size(), iteratorComparator);
+                new PriorityQueue<>(
+                        kvStateIterators.size() + heapPriorityQueueIterators.size(),
+                        iteratorComparator);
 
         for (Tuple2<RocksIteratorWrapper, Integer> rocksIteratorWithKVStateId : kvStateIterators) {
             final RocksIteratorWrapper rocksIterator = rocksIteratorWithKVStateId.f0;
@@ -152,6 +157,16 @@ public class RocksStatesPerKeyGroupMergeIterator implements KeyValueStateIterato
                 }
             }
         }
+
+        for (SingleStateIterator heapQueueIterator : heapPriorityQueueIterators) {
+            if (heapQueueIterator.isValid()) {
+                iteratorPriorityQueue.offer(heapQueueIterator);
+                closeableRegistry.registerCloseable(heapQueueIterator);
+            } else {
+                IOUtils.closeQuietly(heapQueueIterator);
+            }
+        }
+
         return iteratorPriorityQueue;
     }
 
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBFullRestoreOperation.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBFullRestoreOperation.java
index 7b5608a..4005add 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBFullRestoreOperation.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBFullRestoreOperation.java
@@ -45,10 +45,10 @@ import javax.annotation.Nonnull;
 import java.io.File;
 import java.io.IOException;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
-import java.util.stream.Collectors;
 
 /** Encapsulates the process of restoring a RocksDB instance from a full snapshot. */
 public class RocksDBFullRestoreOperation<K> implements RocksDBRestoreOperation {
@@ -115,16 +115,14 @@ public class RocksDBFullRestoreOperation<K> implements RocksDBRestoreOperation {
             throws IOException, RocksDBException, StateMigrationException {
         List<StateMetaInfoSnapshot> restoredMetaInfos =
                 savepointRestoreResult.getStateMetaInfoSnapshots();
-        List<ColumnFamilyHandle> columnFamilyHandles =
-                restoredMetaInfos.stream()
-                        .map(
-                                stateMetaInfoSnapshot -> {
-                                    RocksDbKvStateInfo registeredStateCFHandle =
-                                            this.rocksHandle.getOrRegisterStateColumnFamilyHandle(
-                                                    null, stateMetaInfoSnapshot);
-                                    return registeredStateCFHandle.columnFamilyHandle;
-                                })
-                        .collect(Collectors.toList());
+        Map<Integer, ColumnFamilyHandle> columnFamilyHandles = new HashMap<>();
+        for (int i = 0; i < restoredMetaInfos.size(); i++) {
+            StateMetaInfoSnapshot restoredMetaInfo = restoredMetaInfos.get(i);
+            RocksDbKvStateInfo registeredStateCFHandle =
+                    this.rocksHandle.getOrRegisterStateColumnFamilyHandle(null, restoredMetaInfo);
+            columnFamilyHandles.put(i, registeredStateCFHandle.columnFamilyHandle);
+        }
+
         try (ThrowingIterator<KeyGroup> keyGroups = savepointRestoreResult.getRestoredKeyGroups()) {
             restoreKVStateData(keyGroups, columnFamilyHandles);
         }
@@ -135,17 +133,23 @@ public class RocksDBFullRestoreOperation<K> implements RocksDBRestoreOperation {
      * handle.
      */
     private void restoreKVStateData(
-            ThrowingIterator<KeyGroup> keyGroups, List<ColumnFamilyHandle> columnFamilies)
+            ThrowingIterator<KeyGroup> keyGroups, Map<Integer, ColumnFamilyHandle> columnFamilies)
             throws IOException, RocksDBException, StateMigrationException {
         // for all key-groups in the current state handle...
         try (RocksDBWriteBatchWrapper writeBatchWrapper =
                 new RocksDBWriteBatchWrapper(this.rocksHandle.getDb(), writeBatchSize)) {
+            ColumnFamilyHandle handle = null;
             while (keyGroups.hasNext()) {
                 KeyGroup keyGroup = keyGroups.next();
                 try (ThrowingIterator<KeyGroupEntry> groupEntries = keyGroup.getKeyGroupEntries()) {
+                    int oldKvStateId = -1;
                     while (groupEntries.hasNext()) {
                         KeyGroupEntry groupEntry = groupEntries.next();
-                        ColumnFamilyHandle handle = columnFamilies.get(groupEntry.getKvStateId());
+                        int kvStateId = groupEntry.getKvStateId();
+                        if (kvStateId != oldKvStateId) {
+                            oldKvStateId = kvStateId;
+                            handle = columnFamilies.get(kvStateId);
+                        }
                         writeBatchWrapper.put(handle, groupEntry.getKey(), groupEntry.getValue());
                     }
                 }
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBHeapTimersFullRestoreOperation.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBHeapTimersFullRestoreOperation.java
new file mode 100644
index 0000000..0c859e8
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBHeapTimersFullRestoreOperation.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state.restore;
+
+import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend.RocksDbKvStateInfo;
+import org.apache.flink.contrib.streaming.state.RocksDBNativeMetricOptions;
+import org.apache.flink.contrib.streaming.state.RocksDBWriteBatchWrapper;
+import org.apache.flink.contrib.streaming.state.ttl.RocksDbTtlCompactFiltersManager;
+import org.apache.flink.core.memory.DataInputDeserializer;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.runtime.state.CompositeKeySerializationUtils;
+import org.apache.flink.runtime.state.KeyExtractorFunction;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.Keyed;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.PriorityComparable;
+import org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInfo;
+import org.apache.flink.runtime.state.StateSerializerProvider;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSet;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSnapshotRestoreWrapper;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot.BackendStateType;
+import org.apache.flink.runtime.state.restore.FullSnapshotRestoreOperation;
+import org.apache.flink.runtime.state.restore.KeyGroup;
+import org.apache.flink.runtime.state.restore.KeyGroupEntry;
+import org.apache.flink.runtime.state.restore.SavepointRestoreResult;
+import org.apache.flink.runtime.state.restore.ThrowingIterator;
+import org.apache.flink.util.StateMigrationException;
+
+import org.rocksdb.ColumnFamilyHandle;
+import org.rocksdb.ColumnFamilyOptions;
+import org.rocksdb.DBOptions;
+import org.rocksdb.RocksDBException;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+/**
+ * Encapsulates the process of restoring a RocksDB instance from a full snapshot.
+ *
+ * <p>In contrast to {@link RocksDBFullRestoreOperation}, states of type {@link
+ * BackendStateType#PRIORITY_QUEUE} are not written into RocksDB. Their elements are deserialized
+ * and added to heap priority queues kept in the {@code registeredPQStates} map passed to the
+ * constructor. All other states are written into RocksDB column families.
+ */
+public class RocksDBHeapTimersFullRestoreOperation<K> implements RocksDBRestoreOperation {
+    private final FullSnapshotRestoreOperation<K> savepointRestoreOperation;
+    /** Write batch size used in {@link RocksDBWriteBatchWrapper}. */
+    private final long writeBatchSize;
+
+    /** Heap priority queues by state name; queues restored from the snapshot are added here. */
+    private final LinkedHashMap<String, HeapPriorityQueueSnapshotRestoreWrapper<?>>
+            registeredPQStates;
+    /** Creates heap queues for restored queue states that are not registered yet. */
+    private final HeapPriorityQueueSetFactory priorityQueueFactory;
+    private final int numberOfKeyGroups;
+    /** Reused to read queue elements from the keys of restored entries. */
+    private final DataInputDeserializer deserializer = new DataInputDeserializer();
+
+    private final RocksDBHandle rocksHandle;
+    private final KeyGroupRange keyGroupRange;
+    /** Number of key-group prefix bytes in front of each serialized queue element in a key. */
+    private final int keyGroupPrefixBytes;
+
+    /**
+     * Creates the restore operation. The RocksDB instance is opened in {@link #restore()}, not
+     * here.
+     *
+     * @param registeredPQStates heap priority queues by state name; restored queue states are
+     *     looked up here and added if missing
+     * @param priorityQueueFactory factory for the heap queues of restored states that are missing
+     *     from {@code registeredPQStates}
+     */
+    public RocksDBHeapTimersFullRestoreOperation(
+            KeyGroupRange keyGroupRange,
+            int numberOfKeyGroups,
+            ClassLoader userCodeClassLoader,
+            Map<String, RocksDbKvStateInfo> kvStateInformation,
+            LinkedHashMap<String, HeapPriorityQueueSnapshotRestoreWrapper<?>> registeredPQStates,
+            HeapPriorityQueueSetFactory priorityQueueFactory,
+            StateSerializerProvider<K> keySerializerProvider,
+            File instanceRocksDBPath,
+            DBOptions dbOptions,
+            Function<String, ColumnFamilyOptions> columnFamilyOptionsFactory,
+            RocksDBNativeMetricOptions nativeMetricOptions,
+            MetricGroup metricGroup,
+            @Nonnull Collection<KeyedStateHandle> restoreStateHandles,
+            @Nonnull RocksDbTtlCompactFiltersManager ttlCompactFiltersManager,
+            @Nonnegative long writeBatchSize,
+            Long writeBufferManagerCapacity) {
+        this.writeBatchSize = writeBatchSize;
+        this.rocksHandle =
+                new RocksDBHandle(
+                        kvStateInformation,
+                        instanceRocksDBPath,
+                        dbOptions,
+                        columnFamilyOptionsFactory,
+                        nativeMetricOptions,
+                        metricGroup,
+                        ttlCompactFiltersManager,
+                        writeBufferManagerCapacity);
+        this.savepointRestoreOperation =
+                new FullSnapshotRestoreOperation<>(
+                        keyGroupRange,
+                        userCodeClassLoader,
+                        restoreStateHandles,
+                        keySerializerProvider);
+        this.registeredPQStates = registeredPQStates;
+        this.priorityQueueFactory = priorityQueueFactory;
+        this.numberOfKeyGroups = numberOfKeyGroups;
+        this.keyGroupRange = keyGroupRange;
+        this.keyGroupPrefixBytes =
+                CompositeKeySerializationUtils.computeRequiredBytesInKeyGroupPrefix(
+                        numberOfKeyGroups);
+    }
+
+    /**
+     * Restores all key-groups data that is referenced by the passed state handles.
+     *
+     * <p>Opens the RocksDB instance, then applies each {@link SavepointRestoreResult} produced by
+     * the {@link FullSnapshotRestoreOperation} in turn.
+     */
+    @Override
+    public RocksDBRestoreResult restore()
+            throws IOException, StateMigrationException, RocksDBException {
+        rocksHandle.openDB();
+        try (ThrowingIterator<SavepointRestoreResult> restore =
+                savepointRestoreOperation.restore()) {
+            while (restore.hasNext()) {
+                applyRestoreResult(restore.next());
+            }
+        }
+        // NOTE(review): the trailing -1/null/null arguments presumably mark values that a full
+        // restore does not provide; confirm against the RocksDBRestoreResult constructor.
+        return new RocksDBRestoreResult(
+                this.rocksHandle.getDb(),
+                this.rocksHandle.getDefaultColumnFamilyHandle(),
+                this.rocksHandle.getNativeMetricMonitor(),
+                -1,
+                null,
+                null);
+    }
+
+    /**
+     * Registers the states described by the restored meta infos, then restores their key-group
+     * data.
+     *
+     * <p>Entries are matched to states by the state's index in {@link
+     * SavepointRestoreResult#getStateMetaInfoSnapshots()}. Priority queue states map to heap
+     * queues; a queue is created via {@link #createInternal} if none is registered under its name
+     * yet. All other states map to their RocksDB column family.
+     */
+    private void applyRestoreResult(SavepointRestoreResult savepointRestoreResult)
+            throws IOException, RocksDBException, StateMigrationException {
+        List<StateMetaInfoSnapshot> restoredMetaInfos =
+                savepointRestoreResult.getStateMetaInfoSnapshots();
+        Map<Integer, ColumnFamilyHandle> columnFamilyHandles = new HashMap<>();
+        Map<Integer, HeapPriorityQueueSnapshotRestoreWrapper<?>> restoredPQStates = new HashMap<>();
+        for (int i = 0; i < restoredMetaInfos.size(); i++) {
+            StateMetaInfoSnapshot restoredMetaInfo = restoredMetaInfos.get(i);
+            if (restoredMetaInfo.getBackendStateType() == BackendStateType.PRIORITY_QUEUE) {
+                String stateName = restoredMetaInfo.getName();
+                HeapPriorityQueueSnapshotRestoreWrapper<?> queueWrapper =
+                        registeredPQStates.computeIfAbsent(
+                                stateName,
+                                key ->
+                                        createInternal(
+                                                new RegisteredPriorityQueueStateBackendMetaInfo<>(
+                                                        restoredMetaInfo)));
+                restoredPQStates.put(i, queueWrapper);
+            } else {
+                RocksDbKvStateInfo registeredStateCFHandle =
+                        this.rocksHandle.getOrRegisterStateColumnFamilyHandle(
+                                null, restoredMetaInfo);
+                columnFamilyHandles.put(i, registeredStateCFHandle.columnFamilyHandle);
+            }
+        }
+
+        try (ThrowingIterator<KeyGroup> keyGroups = savepointRestoreResult.getRestoredKeyGroups()) {
+            restoreKVStateData(keyGroups, columnFamilyHandles, restoredPQStates);
+        }
+    }
+
+    /**
+     * Restore the KV-state / ColumnFamily data for all key-groups referenced by the current state
+     * handle.
+     *
+     * <p>Entries of priority queue states are added to the matching heap queue. All other entries
+     * are written to their column family through a {@link RocksDBWriteBatchWrapper}.
+     *
+     * @throws IllegalStateException if an entry's state id is in neither of the two maps
+     */
+    private void restoreKVStateData(
+            ThrowingIterator<KeyGroup> keyGroups,
+            Map<Integer, ColumnFamilyHandle> columnFamilies,
+            Map<Integer, HeapPriorityQueueSnapshotRestoreWrapper<?>> restoredPQStates)
+            throws IOException, RocksDBException, StateMigrationException {
+        // for all key-groups in the current state handle...
+        try (RocksDBWriteBatchWrapper writeBatchWrapper =
+                new RocksDBWriteBatchWrapper(this.rocksHandle.getDb(), writeBatchSize)) {
+            HeapPriorityQueueSnapshotRestoreWrapper<HeapPriorityQueueElement> restoredPQ = null;
+            ColumnFamilyHandle handle = null;
+            while (keyGroups.hasNext()) {
+                KeyGroup keyGroup = keyGroups.next();
+                try (ThrowingIterator<KeyGroupEntry> groupEntries = keyGroup.getKeyGroupEntries()) {
+                    // Reset per key group; state ids are indices >= 0, so -1 forces a lookup.
+                    int oldKvStateId = -1;
+                    while (groupEntries.hasNext()) {
+                        KeyGroupEntry groupEntry = groupEntries.next();
+                        int kvStateId = groupEntry.getKvStateId();
+                        // Look the target up again only when the state id changes between
+                        // consecutive entries.
+                        if (kvStateId != oldKvStateId) {
+                            oldKvStateId = kvStateId;
+                            handle = columnFamilies.get(kvStateId);
+                            restoredPQ = getRestoredPQ(restoredPQStates, kvStateId);
+                        }
+                        if (restoredPQ != null) {
+                            restoreQueueElement(restoredPQ, groupEntry);
+                        } else if (handle != null) {
+                            writeBatchWrapper.put(
+                                    handle, groupEntry.getKey(), groupEntry.getValue());
+                        } else {
+                            throw new IllegalStateException("Unknown state id: " + kvStateId);
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * Deserializes the queue element stored in the entry's key, after the key-group prefix, and
+     * adds it to the heap queue. The entry's value is not read.
+     */
+    private void restoreQueueElement(
+            HeapPriorityQueueSnapshotRestoreWrapper<HeapPriorityQueueElement> restoredPQ,
+            KeyGroupEntry groupEntry)
+            throws IOException {
+        deserializer.setBuffer(groupEntry.getKey());
+        deserializer.skipBytesToRead(keyGroupPrefixBytes);
+        HeapPriorityQueueElement queueElement =
+                restoredPQ.getMetaInfo().getElementSerializer().deserialize(deserializer);
+        restoredPQ.getPriorityQueue().add(queueElement);
+    }
+
+    /**
+     * Returns the heap queue registered for the given state id, or null if that id is not a
+     * priority queue state.
+     */
+    @SuppressWarnings("unchecked")
+    private HeapPriorityQueueSnapshotRestoreWrapper<HeapPriorityQueueElement> getRestoredPQ(
+            Map<Integer, HeapPriorityQueueSnapshotRestoreWrapper<?>> restoredPQStates,
+            int kvStateId) {
+        return (HeapPriorityQueueSnapshotRestoreWrapper<HeapPriorityQueueElement>)
+                restoredPQStates.get(kvStateId);
+    }
+
+    /**
+     * Creates a heap priority queue for the given meta info via {@link #priorityQueueFactory}, and
+     * wraps it for snapshotting and restore over this operation's key-group range.
+     */
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    private <T extends HeapPriorityQueueElement & PriorityComparable<? super T> & Keyed<?>>
+            HeapPriorityQueueSnapshotRestoreWrapper<T> createInternal(
+                    RegisteredPriorityQueueStateBackendMetaInfo metaInfo) {
+
+        final String stateName = metaInfo.getName();
+        final HeapPriorityQueueSet<T> priorityQueue =
+                priorityQueueFactory.create(stateName, metaInfo.getElementSerializer());
+
+        return new HeapPriorityQueueSnapshotRestoreWrapper<>(
+                priorityQueue,
+                metaInfo,
+                KeyExtractorFunction.forKeyedObjects(),
+                keyGroupRange,
+                numberOfKeyGroups);
+    }
+
+    /** Closes the underlying RocksDB handle. */
+    @Override
+    public void close() throws Exception {
+        this.rocksHandle.close();
+    }
+}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksDBFullSnapshotResources.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksDBFullSnapshotResources.java
index eff0e8a..09937d6 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksDBFullSnapshotResources.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksDBFullSnapshotResources.java
@@ -22,8 +22,10 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend;
 import org.apache.flink.contrib.streaming.state.RocksIteratorWrapper;
+import org.apache.flink.contrib.streaming.state.iterator.RocksQueueIterator;
 import org.apache.flink.contrib.streaming.state.iterator.RocksStatesPerKeyGroupMergeIterator;
 import org.apache.flink.contrib.streaming.state.iterator.RocksTransformingIteratorWrapper;
+import org.apache.flink.contrib.streaming.state.iterator.SingleStateIterator;
 import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.runtime.state.FullSnapshotResources;
 import org.apache.flink.runtime.state.KeyGroupRange;
@@ -31,6 +33,7 @@ import org.apache.flink.runtime.state.KeyValueStateIterator;
 import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueStateSnapshot;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.util.IOUtils;
 import org.apache.flink.util.ResourceGuard;
@@ -61,11 +64,13 @@ class RocksDBFullSnapshotResources<K> implements FullSnapshotResources<K> {
     private final KeyGroupRange keyGroupRange;
     private final TypeSerializer<K> keySerializer;
     private final StreamCompressionDecorator streamCompressionDecorator;
+    private final List<HeapPriorityQueueStateSnapshot<?>> heapPriorityQueuesSnapshots;
 
     public RocksDBFullSnapshotResources(
             ResourceGuard.Lease lease,
             Snapshot snapshot,
             List<RocksDBKeyedStateBackend.RocksDbKvStateInfo> metaDataCopy,
+            List<HeapPriorityQueueStateSnapshot<?>> heapPriorityQueuesSnapshots,
             List<StateMetaInfoSnapshot> stateMetaInfoSnapshots,
             RocksDB db,
             int keyGroupPrefixBytes,
@@ -75,6 +80,7 @@ class RocksDBFullSnapshotResources<K> implements FullSnapshotResources<K> {
         this.lease = lease;
         this.snapshot = snapshot;
         this.stateMetaInfoSnapshots = stateMetaInfoSnapshots;
+        this.heapPriorityQueuesSnapshots = heapPriorityQueuesSnapshots;
         this.db = db;
         this.keyGroupPrefixBytes = keyGroupPrefixBytes;
         this.keyGroupRange = keyGroupRange;
@@ -115,10 +121,16 @@ class RocksDBFullSnapshotResources<K> implements FullSnapshotResources<K> {
             List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators =
                     createKVStateIterators(closeableRegistry, readOptions);
 
+            List<SingleStateIterator> heapPriorityQueueIterators =
+                    createHeapPriorityQueueIterators();
+
             // Here we transfer ownership of the required resources to the
             // RocksStatesPerKeyGroupMergeIterator
             return new RocksStatesPerKeyGroupMergeIterator(
-                    closeableRegistry, new ArrayList<>(kvStateIterators), keyGroupPrefixBytes);
+                    closeableRegistry,
+                    kvStateIterators,
+                    heapPriorityQueueIterators,
+                    keyGroupPrefixBytes);
         } catch (Throwable t) {
             // If anything goes wrong, clean up our stuff. If things went smoothly the
             // merging iterator is now responsible for closing the resources
@@ -127,6 +139,18 @@ class RocksDBFullSnapshotResources<K> implements FullSnapshotResources<K> {
         }
     }
 
+    private List<SingleStateIterator> createHeapPriorityQueueIterators() {
+        // Queue ids continue after the RocksDB kv states, in heapPriorityQueuesSnapshots order.
+        int nextId = metaData.size();
+        final List<SingleStateIterator> iterators =
+                new ArrayList<>(heapPriorityQueuesSnapshots.size());
+        for (HeapPriorityQueueStateSnapshot<?> snapshot : heapPriorityQueuesSnapshots) {
+            iterators.add(
+                    new RocksQueueIterator(snapshot, keyGroupRange, keyGroupPrefixBytes, nextId++));
+        }
+        return iterators;
+    }
+
     private List<Tuple2<RocksIteratorWrapper, Integer>> createKVStateIterators(
             CloseableRegistry closeableRegistry, ReadOptions readOptions) throws IOException {
 
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
index 0b83bd8..b3318f2 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
@@ -31,6 +31,8 @@ import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.LocalRecoveryConfig;
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSnapshotRestoreWrapper;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueStateSnapshot;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.util.ResourceGuard;
 import org.apache.flink.util.function.SupplierWithException;
@@ -64,11 +66,17 @@ public class RocksFullSnapshotStrategy<K>
     /** This decorator is used to apply compression per key-group for the written snapshot data. */
     @Nonnull private final StreamCompressionDecorator keyGroupCompressionDecorator;
 
+    private final LinkedHashMap<String, HeapPriorityQueueSnapshotRestoreWrapper<?>>
+            registeredPQStates;
+
     public RocksFullSnapshotStrategy(
             @Nonnull RocksDB db,
             @Nonnull ResourceGuard rocksDBResourceGuard,
             @Nonnull TypeSerializer<K> keySerializer,
             @Nonnull LinkedHashMap<String, RocksDbKvStateInfo> kvStateInformation,
+            @Nonnull
+                    LinkedHashMap<String, HeapPriorityQueueSnapshotRestoreWrapper<?>>
+                            registeredPQStates,
             @Nonnull KeyGroupRange keyGroupRange,
             @Nonnegative int keyGroupPrefixBytes,
             @Nonnull LocalRecoveryConfig localRecoveryConfig,
@@ -84,6 +92,7 @@ public class RocksFullSnapshotStrategy<K>
                 localRecoveryConfig);
 
         this.keyGroupCompressionDecorator = keyGroupCompressionDecorator;
+        this.registeredPQStates = registeredPQStates;
     }
 
     @Override
@@ -99,6 +108,13 @@ public class RocksFullSnapshotStrategy<K>
             metaDataCopy.add(stateInfo);
         }
 
+        List<HeapPriorityQueueStateSnapshot<?>> heapPriorityQueuesSnapshots =
+                new ArrayList<>(registeredPQStates.size());
+        for (HeapPriorityQueueSnapshotRestoreWrapper<?> stateInfo : registeredPQStates.values()) {
+            stateMetaInfoSnapshots.add(stateInfo.getMetaInfo().snapshot());
+            heapPriorityQueuesSnapshots.add(stateInfo.stateSnapshot());
+        }
+
         final ResourceGuard.Lease lease = rocksDBResourceGuard.acquireResource();
         final Snapshot snapshot = db.getSnapshot();
 
@@ -106,6 +122,7 @@ public class RocksFullSnapshotStrategy<K>
                 lease,
                 snapshot,
                 metaDataCopy,
+                heapPriorityQueuesSnapshots,
                 stateMetaInfoSnapshots,
                 db,
                 keyGroupPrefixBytes,
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksKeyGroupsRocksSingleStateIteratorTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksKeyGroupsRocksSingleStateIteratorTest.java
index 6361abc..cac9d61 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksKeyGroupsRocksSingleStateIteratorTest.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksKeyGroupsRocksSingleStateIteratorTest.java
@@ -60,7 +60,10 @@ public class RocksKeyGroupsRocksSingleStateIteratorTest {
     public void testEmptyMergeIterator() throws Exception {
         RocksStatesPerKeyGroupMergeIterator emptyIterator =
                 new RocksStatesPerKeyGroupMergeIterator(
-                        new CloseableRegistry(), Collections.emptyList(), 2);
+                        new CloseableRegistry(),
+                        Collections.emptyList(),
+                        Collections.emptyList(),
+                        2);
         Assert.assertFalse(emptyIterator.isValid());
     }
 
@@ -134,6 +137,7 @@ public class RocksKeyGroupsRocksSingleStateIteratorTest {
                     new RocksStatesPerKeyGroupMergeIterator(
                             closeableRegistry,
                             rocksIteratorsWithKVStateId,
+                            Collections.emptyList(),
                             maxParallelism <= Byte.MAX_VALUE ? 1 : 2)) {
 
                 int prevKVState = -1;
diff --git a/flink-tests/src/test/java/org/apache/flink/test/state/BackendSwitchSpecs.java b/flink-tests/src/test/java/org/apache/flink/test/state/BackendSwitchSpecs.java
index 2cccafd..a72ca84 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/state/BackendSwitchSpecs.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/state/BackendSwitchSpecs.java
@@ -24,7 +24,7 @@ import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend;
 import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackendBuilder;
 import org.apache.flink.contrib.streaming.state.RocksDBResourceContainer;
-import org.apache.flink.contrib.streaming.state.RocksDBStateBackend;
+import org.apache.flink.contrib.streaming.state.RocksDBStateBackend.PriorityQueueStateType;
 import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
@@ -66,7 +66,10 @@ public final class BackendSwitchSpecs {
     }
 
     /** Specification for a {@link RocksDBKeyedStateBackend}. */
-    static final BackendSwitchSpec ROCKS = new RocksSpec();
+    static final BackendSwitchSpec ROCKS = new RocksSpec(PriorityQueueStateType.ROCKSDB);
+
+    /** Specification for a {@link RocksDBKeyedStateBackend} which stores its timers on heap. */
+    static final BackendSwitchSpec ROCKS_HEAP_TIMERS = new RocksSpec(PriorityQueueStateType.HEAP);
 
     /** Specification for a {@link HeapKeyedStateBackend}. */
     static final BackendSwitchSpec HEAP = new HeapSpec();
@@ -74,6 +77,11 @@ public final class BackendSwitchSpecs {
     private static final class RocksSpec implements BackendSwitchSpec {
 
         private final TemporaryFolder temporaryFolder = new TemporaryFolder();
+        private final PriorityQueueStateType queueStateType;
+        /** Spec for a RocksDB backend using {@code queueStateType} to store its timers. */
+        public RocksSpec(PriorityQueueStateType queueStateType) {
+            this.queueStateType = queueStateType;
+        }
 
         @Override
         public CheckpointableKeyedStateBackend<String> createBackend(
@@ -97,7 +105,7 @@ public final class BackendSwitchSpecs {
                             keyGroupRange,
                             new ExecutionConfig(),
                             TestLocalRecoveryConfig.disabled(),
-                            RocksDBStateBackend.PriorityQueueStateType.ROCKSDB,
+                            queueStateType,
                             TtlTimeProvider.DEFAULT,
                             new UnregisteredMetricsGroup(),
                             stateHandles,
@@ -113,7 +121,7 @@ public final class BackendSwitchSpecs {
 
         @Override
         public String toString() {
-            return "ROCKS";
+            return "ROCKS(" + queueStateType + ")";
         }
     }
 
diff --git a/flink-tests/src/test/java/org/apache/flink/test/state/RocksSavepointStateBackendSwitchTest.java b/flink-tests/src/test/java/org/apache/flink/test/state/RocksSavepointStateBackendSwitchTest.java
index d7f7e2d..7a52c55 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/state/RocksSavepointStateBackendSwitchTest.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/state/RocksSavepointStateBackendSwitchTest.java
@@ -25,16 +25,28 @@ import org.junit.runners.Parameterized;
 
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.List;
+import java.util.stream.Collectors;
 
 /** Tests for switching a RocksDB state backend to a different one. */
 @RunWith(Parameterized.class)
 public class RocksSavepointStateBackendSwitchTest extends SavepointStateBackendSwitchTestBase {
-    public RocksSavepointStateBackendSwitchTest(BackendSwitchSpec toBackend) {
-        super(BackendSwitchSpecs.ROCKS, toBackend);
+    public RocksSavepointStateBackendSwitchTest(
+            BackendSwitchSpec fromBackend, BackendSwitchSpec toBackend) {
+        super(fromBackend, toBackend);
     }
 
-    @Parameterized.Parameters(name = "to: {0}")
-    public static Collection<BackendSwitchSpec> targetBackends() {
-        return Arrays.asList(BackendSwitchSpecs.HEAP, BackendSwitchSpecs.ROCKS);
+    @Parameterized.Parameters(name = "from: {0} to: {1}")
+    public static Collection<BackendSwitchSpec[]> targetBackends() {
+        final List<BackendSwitchSpec> sources =
+                Arrays.asList(BackendSwitchSpecs.ROCKS_HEAP_TIMERS, BackendSwitchSpecs.ROCKS);
+        final List<BackendSwitchSpec> targets =
+                Arrays.asList(
+                        BackendSwitchSpecs.HEAP,
+                        BackendSwitchSpecs.ROCKS,
+                        BackendSwitchSpecs.ROCKS_HEAP_TIMERS);
+        return sources.stream() // cross product: every source paired with every target
+                .flatMap(src -> targets.stream().map(dst -> new BackendSwitchSpec[] {src, dst}))
+                .collect(Collectors.toList());
     }
 }