You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by dw...@apache.org on 2021/02/22 10:45:20 UTC

[flink] 02/03: [FLINK-21206] Write savepoints in unified format from HeapStateBackend

This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 7564c810e55f952ea7014a707c9487b777131c2d
Author: Dawid Wysakowicz <dw...@apache.org>
AuthorDate: Wed Jan 27 12:34:27 2021 +0100

    [FLINK-21206] Write savepoints in unified format from HeapStateBackend
    
    This closes #14925
---
 .../runtime/state/heap/HeapKeyedStateBackend.java  |  19 ++-
 .../state/heap/HeapKeyedStateBackendBuilder.java   |  16 +-
 .../runtime/state/heap/HeapSavepointStrategy.java  |  95 +++++++++++
 .../runtime/state/heap/HeapSnapshotResources.java  | 190 +++++++++++++++++++++
 .../runtime/state/heap/HeapSnapshotStrategy.java   | 116 ++-----------
 flink-tests/pom.xml                                |   9 +
 .../state/HeapSavepointStateBackendSwitchTest.java |   4 +-
 7 files changed, 338 insertions(+), 111 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index c81bb68..0b42a32 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -106,7 +106,9 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
      * The snapshot strategy for this backend. This determines, e.g., if snapshots are synchronous
      * or asynchronous.
      */
-    private final SnapshotStrategyRunner<KeyedStateHandle, ?> snapshotStrategyRunner;
+    private final SnapshotStrategyRunner<KeyedStateHandle, ?> checkpointStrategyRunner;
+
+    private final SnapshotStrategyRunner<KeyedStateHandle, ?> savepointStrategyRunner;
 
     private final StateTableFactory<K> stateTableFactory;
 
@@ -125,7 +127,8 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
             Map<String, HeapPriorityQueueSnapshotRestoreWrapper<?>> registeredPQStates,
             LocalRecoveryConfig localRecoveryConfig,
             HeapPriorityQueueSetFactory priorityQueueSetFactory,
-            SnapshotStrategyRunner<KeyedStateHandle, ?> snapshotStrategyRunner,
+            SnapshotStrategyRunner<KeyedStateHandle, ?> checkpointStrategyRunner,
+            SnapshotStrategyRunner<KeyedStateHandle, ?> savepointStrategyRunner,
             StateTableFactory<K> stateTableFactory,
             InternalKeyContext<K> keyContext) {
         super(
@@ -141,7 +144,8 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
         this.registeredPQStates = registeredPQStates;
         this.localRecoveryConfig = localRecoveryConfig;
         this.priorityQueueSetFactory = priorityQueueSetFactory;
-        this.snapshotStrategyRunner = snapshotStrategyRunner;
+        this.checkpointStrategyRunner = checkpointStrategyRunner;
+        this.savepointStrategyRunner = savepointStrategyRunner;
         this.stateTableFactory = stateTableFactory;
         LOG.info("Initializing heap keyed state backend with stream factory.");
     }
@@ -356,8 +360,13 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
             @Nonnull CheckpointOptions checkpointOptions)
             throws Exception {
 
-        return snapshotStrategyRunner.snapshot(
-                checkpointId, timestamp, streamFactory, checkpointOptions);
+        if (checkpointOptions.getCheckpointType().isSavepoint()) {
+            return savepointStrategyRunner.snapshot(
+                    checkpointId, timestamp, streamFactory, checkpointOptions);
+        } else {
+            return checkpointStrategyRunner.snapshot(
+                    checkpointId, timestamp, streamFactory, checkpointOptions);
+        }
     }
 
     @Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendBuilder.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendBuilder.java
index 6519461..3991ba0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendBuilder.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendBuilder.java
@@ -96,6 +96,14 @@ public class HeapKeyedStateBackendBuilder<K> extends AbstractKeyedStateBackendBu
         CloseableRegistry cancelStreamRegistryForBackend = new CloseableRegistry();
         HeapSnapshotStrategy<K> snapshotStrategy =
                 initSnapshotStrategy(registeredKVStates, registeredPQStates);
+        HeapSavepointStrategy<K> savepointStrategy =
+                new HeapSavepointStrategy<>(
+                        registeredKVStates,
+                        registeredPQStates,
+                        keyGroupCompressionDecorator,
+                        keyGroupRange,
+                        keySerializerProvider,
+                        numberOfKeyGroups);
         InternalKeyContext<K> keyContext =
                 new InternalKeyContextImpl<>(keyGroupRange, numberOfKeyGroups);
 
@@ -124,6 +132,11 @@ public class HeapKeyedStateBackendBuilder<K> extends AbstractKeyedStateBackendBu
                         snapshotStrategy,
                         cancelStreamRegistryForBackend,
                         asynchronousSnapshots ? ASYNCHRONOUS : SYNCHRONOUS),
+                new SnapshotStrategyRunner<>(
+                        "Heap backend savepoint",
+                        savepointStrategy,
+                        cancelStreamRegistryForBackend,
+                        asynchronousSnapshots ? ASYNCHRONOUS : SYNCHRONOUS),
                 stateTableFactory,
                 keyContext);
     }
@@ -187,6 +200,7 @@ public class HeapKeyedStateBackendBuilder<K> extends AbstractKeyedStateBackendBu
                 keyGroupCompressionDecorator,
                 localRecoveryConfig,
                 keyGroupRange,
-                keySerializerProvider);
+                keySerializerProvider,
+                numberOfKeyGroups);
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSavepointStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSavepointStrategy.java
new file mode 100644
index 0000000..ed70ed7
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSavepointStrategy.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.CheckpointStreamWithResultProvider;
+import org.apache.flink.runtime.state.CheckpointedStateScope;
+import org.apache.flink.runtime.state.FullSnapshotAsyncWriter;
+import org.apache.flink.runtime.state.FullSnapshotResources;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.SnapshotStrategy;
+import org.apache.flink.runtime.state.StateSerializerProvider;
+import org.apache.flink.runtime.state.StreamCompressionDecorator;
+
+import javax.annotation.Nonnull;
+
+import java.util.Map;
+
+/** A strategy how to perform a savepoint of a {@link HeapKeyedStateBackend}. */
+class HeapSavepointStrategy<K>
+        implements SnapshotStrategy<KeyedStateHandle, FullSnapshotResources<K>> {
+
+    private final Map<String, StateTable<K, ?, ?>> registeredKVStates;
+    private final Map<String, HeapPriorityQueueSnapshotRestoreWrapper<?>> registeredPQStates;
+    private final StreamCompressionDecorator keyGroupCompressionDecorator;
+    private final KeyGroupRange keyGroupRange;
+    private final StateSerializerProvider<K> keySerializerProvider;
+    private final int totalKeyGroups;
+
+    HeapSavepointStrategy(
+            Map<String, StateTable<K, ?, ?>> registeredKVStates,
+            Map<String, HeapPriorityQueueSnapshotRestoreWrapper<?>> registeredPQStates,
+            StreamCompressionDecorator keyGroupCompressionDecorator,
+            KeyGroupRange keyGroupRange,
+            StateSerializerProvider<K> keySerializerProvider,
+            int totalKeyGroups) {
+        this.registeredKVStates = registeredKVStates;
+        this.registeredPQStates = registeredPQStates;
+        this.keyGroupCompressionDecorator = keyGroupCompressionDecorator;
+        this.keyGroupRange = keyGroupRange;
+        this.keySerializerProvider = keySerializerProvider;
+        this.totalKeyGroups = totalKeyGroups;
+    }
+
+    @Override
+    public FullSnapshotResources<K> syncPrepareResources(long checkpointId) {
+        return HeapSnapshotResources.create(
+                registeredKVStates,
+                registeredPQStates,
+                keyGroupCompressionDecorator,
+                keyGroupRange,
+                getKeySerializer(),
+                totalKeyGroups);
+    }
+
+    @Override
+    public SnapshotResultSupplier<KeyedStateHandle> asyncSnapshot(
+            FullSnapshotResources<K> syncPartResource,
+            long checkpointId,
+            long timestamp,
+            @Nonnull CheckpointStreamFactory streamFactory,
+            @Nonnull CheckpointOptions checkpointOptions) {
+        // The backend routes only savepoints here; the assert below is only checked with -ea.
+        assert checkpointOptions.getCheckpointType().isSavepoint();
+        return new FullSnapshotAsyncWriter<>(
+                checkpointOptions.getCheckpointType(),
+                () ->
+                        CheckpointStreamWithResultProvider.createSimpleStream(
+                                CheckpointedStateScope.EXCLUSIVE, streamFactory),
+                syncPartResource);
+    }
+
+    public TypeSerializer<K> getKeySerializer() {
+        return keySerializerProvider.currentSchemaSerializer();
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSnapshotResources.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSnapshotResources.java
new file mode 100644
index 0000000..aa2c0af
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSnapshotResources.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.state.FullSnapshotResources;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyValueStateIterator;
+import org.apache.flink.runtime.state.StateSnapshot;
+import org.apache.flink.runtime.state.StateSnapshotRestore;
+import org.apache.flink.runtime.state.StreamCompressionDecorator;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A set of resources required to take a checkpoint or savepoint from a {@link
+ * HeapKeyedStateBackend}; built in the synchronous phase and released via {@link #release()}.
+ */
+@Internal
+final class HeapSnapshotResources<K> implements FullSnapshotResources<K> {
+    private final List<StateMetaInfoSnapshot> metaInfoSnapshots;
+    private final Map<StateUID, StateSnapshot> cowStateStableSnapshots;
+    private final StreamCompressionDecorator streamCompressionDecorator;
+    private final Map<StateUID, Integer> stateNamesToId;
+    private final KeyGroupRange keyGroupRange;
+    private final TypeSerializer<K> keySerializer;
+    private final int totalKeyGroups;
+
+    private HeapSnapshotResources(
+            List<StateMetaInfoSnapshot> metaInfoSnapshots,
+            Map<StateUID, StateSnapshot> cowStateStableSnapshots,
+            StreamCompressionDecorator streamCompressionDecorator,
+            Map<StateUID, Integer> stateNamesToId,
+            KeyGroupRange keyGroupRange,
+            TypeSerializer<K> keySerializer,
+            int totalKeyGroups) {
+        this.metaInfoSnapshots = metaInfoSnapshots;
+        this.cowStateStableSnapshots = cowStateStableSnapshots;
+        this.streamCompressionDecorator = streamCompressionDecorator;
+        this.stateNamesToId = stateNamesToId;
+        this.keyGroupRange = keyGroupRange;
+        this.keySerializer = keySerializer;
+        this.totalKeyGroups = totalKeyGroups;
+    }
+
+    public static <K> HeapSnapshotResources<K> create(
+            Map<String, StateTable<K, ?, ?>> registeredKVStates,
+            Map<String, HeapPriorityQueueSnapshotRestoreWrapper<?>> registeredPQStates,
+            StreamCompressionDecorator streamCompressionDecorator,
+            KeyGroupRange keyGroupRange,
+            TypeSerializer<K> keySerializer,
+            int totalKeyGroups) {
+        // Nothing is registered, so there are no state snapshots to take or release.
+        if (registeredKVStates.isEmpty() && registeredPQStates.isEmpty()) {
+            return new HeapSnapshotResources<>(
+                    Collections.emptyList(),
+                    Collections.emptyMap(),
+                    streamCompressionDecorator,
+                    Collections.emptyMap(),
+                    keyGroupRange,
+                    keySerializer,
+                    totalKeyGroups);
+        }
+        // NOTE(review): the short limit presumably reflects how state ids are written; confirm.
+        int numStates = registeredKVStates.size() + registeredPQStates.size();
+
+        Preconditions.checkState(
+                numStates <= Short.MAX_VALUE,
+                "Too many states: "
+                        + numStates
+                        + ". Currently at most "
+                        + Short.MAX_VALUE
+                        + " states are supported");
+
+        final List<StateMetaInfoSnapshot> metaInfoSnapshots = new ArrayList<>(numStates);
+        final Map<StateUID, Integer> stateNamesToId = new HashMap<>(numStates);
+        final Map<StateUID, StateSnapshot> cowStateStableSnapshots = new HashMap<>(numStates);
+        // Ids are assigned sequentially: all KV states first, then priority queue states.
+        processSnapshotMetaInfoForAllStates(
+                metaInfoSnapshots,
+                cowStateStableSnapshots,
+                stateNamesToId,
+                registeredKVStates,
+                StateMetaInfoSnapshot.BackendStateType.KEY_VALUE);
+
+        processSnapshotMetaInfoForAllStates(
+                metaInfoSnapshots,
+                cowStateStableSnapshots,
+                stateNamesToId,
+                registeredPQStates,
+                StateMetaInfoSnapshot.BackendStateType.PRIORITY_QUEUE);
+
+        return new HeapSnapshotResources<>(
+                metaInfoSnapshots,
+                cowStateStableSnapshots,
+                streamCompressionDecorator,
+                stateNamesToId,
+                keyGroupRange,
+                keySerializer,
+                totalKeyGroups);
+    }
+
+    private static void processSnapshotMetaInfoForAllStates(
+            List<StateMetaInfoSnapshot> metaInfoSnapshots,
+            Map<StateUID, StateSnapshot> cowStateStableSnapshots,
+            Map<StateUID, Integer> stateNamesToId,
+            Map<String, ? extends StateSnapshotRestore> registeredStates,
+            StateMetaInfoSnapshot.BackendStateType stateType) {
+        // NOTE(review): null states consume an id but add no snapshot; confirm intended.
+        for (Map.Entry<String, ? extends StateSnapshotRestore> kvState :
+                registeredStates.entrySet()) {
+            final StateUID stateUid = StateUID.of(kvState.getKey(), stateType);
+            stateNamesToId.put(stateUid, stateNamesToId.size());
+            StateSnapshotRestore state = kvState.getValue();
+            if (null != state) {
+                final StateSnapshot stateSnapshot = state.stateSnapshot();
+                metaInfoSnapshots.add(stateSnapshot.getMetaInfoSnapshot());
+                cowStateStableSnapshots.put(stateUid, stateSnapshot);
+            }
+        }
+    }
+
+    @Override
+    public void release() {
+        for (StateSnapshot stateSnapshot : cowStateStableSnapshots.values()) {
+            stateSnapshot.release();
+        }
+    }
+
+    public List<StateMetaInfoSnapshot> getMetaInfoSnapshots() {
+        return metaInfoSnapshots;
+    }
+
+    @Override
+    public KeyValueStateIterator createKVStateIterator() throws IOException {
+        return new HeapKeyValueStateIterator(
+                keyGroupRange,
+                keySerializer,
+                totalKeyGroups,
+                stateNamesToId,
+                cowStateStableSnapshots);
+    }
+
+    @Override
+    public KeyGroupRange getKeyGroupRange() {
+        return keyGroupRange;
+    }
+
+    @Override
+    public TypeSerializer<K> getKeySerializer() {
+        return keySerializer;
+    }
+
+    @Override
+    public StreamCompressionDecorator getStreamCompressionDecorator() {
+        return streamCompressionDecorator;
+    }
+
+    public Map<StateUID, StateSnapshot> getCowStateStableSnapshots() {
+        return cowStateStableSnapshots;
+    }
+
+    public Map<StateUID, Integer> getStateNamesToId() {
+        return stateNamesToId;
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSnapshotStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSnapshotStrategy.java
index fe48c90..78b96e5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSnapshotStrategy.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapSnapshotStrategy.java
@@ -30,26 +30,20 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.LocalRecoveryConfig;
-import org.apache.flink.runtime.state.SnapshotResources;
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.SnapshotStrategy;
 import org.apache.flink.runtime.state.StateSerializerProvider;
 import org.apache.flink.runtime.state.StateSnapshot;
-import org.apache.flink.runtime.state.StateSnapshotRestore;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
-import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 
 import javax.annotation.Nonnull;
 
 import java.io.IOException;
 import java.io.OutputStream;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -60,7 +54,7 @@ import static org.apache.flink.runtime.state.CheckpointStreamWithResultProvider.
 
 /** A strategy how to perform a snapshot of a {@link HeapKeyedStateBackend}. */
 class HeapSnapshotStrategy<K>
-        implements SnapshotStrategy<KeyedStateHandle, HeapSnapshotStrategy.HeapSnapshotResources> {
+        implements SnapshotStrategy<KeyedStateHandle, HeapSnapshotResources<K>> {
 
     private final Map<String, StateTable<K, ?, ?>> registeredKVStates;
     private final Map<String, HeapPriorityQueueSnapshotRestoreWrapper<?>> registeredPQStates;
@@ -68,6 +62,7 @@ class HeapSnapshotStrategy<K>
     private final LocalRecoveryConfig localRecoveryConfig;
     private final KeyGroupRange keyGroupRange;
     private final StateSerializerProvider<K> keySerializerProvider;
+    private final int totalKeyGroups;
 
     HeapSnapshotStrategy(
             Map<String, StateTable<K, ?, ?>> registeredKVStates,
@@ -75,58 +70,31 @@ class HeapSnapshotStrategy<K>
             StreamCompressionDecorator keyGroupCompressionDecorator,
             LocalRecoveryConfig localRecoveryConfig,
             KeyGroupRange keyGroupRange,
-            StateSerializerProvider<K> keySerializerProvider) {
+            StateSerializerProvider<K> keySerializerProvider,
+            int totalKeyGroups) {
         this.registeredKVStates = registeredKVStates;
         this.registeredPQStates = registeredPQStates;
         this.keyGroupCompressionDecorator = keyGroupCompressionDecorator;
         this.localRecoveryConfig = localRecoveryConfig;
         this.keyGroupRange = keyGroupRange;
         this.keySerializerProvider = keySerializerProvider;
+        this.totalKeyGroups = totalKeyGroups;
     }
 
     @Override
-    public HeapSnapshotResources syncPrepareResources(long checkpointId) {
-
-        if (!hasRegisteredState()) {
-            return new HeapSnapshotResources(
-                    Collections.emptyList(), Collections.emptyMap(), Collections.emptyMap());
-        }
-
-        int numStates = registeredKVStates.size() + registeredPQStates.size();
-
-        Preconditions.checkState(
-                numStates <= Short.MAX_VALUE,
-                "Too many states: "
-                        + numStates
-                        + ". Currently at most "
-                        + Short.MAX_VALUE
-                        + " states are supported");
-
-        final List<StateMetaInfoSnapshot> metaInfoSnapshots = new ArrayList<>(numStates);
-        final Map<StateUID, Integer> stateNamesToId = new HashMap<>(numStates);
-        final Map<StateUID, StateSnapshot> cowStateStableSnapshots = new HashMap<>(numStates);
-
-        processSnapshotMetaInfoForAllStates(
-                metaInfoSnapshots,
-                cowStateStableSnapshots,
-                stateNamesToId,
+    public HeapSnapshotResources<K> syncPrepareResources(long checkpointId) {
+        return HeapSnapshotResources.create(
                 registeredKVStates,
-                StateMetaInfoSnapshot.BackendStateType.KEY_VALUE);
-
-        processSnapshotMetaInfoForAllStates(
-                metaInfoSnapshots,
-                cowStateStableSnapshots,
-                stateNamesToId,
                 registeredPQStates,
-                StateMetaInfoSnapshot.BackendStateType.PRIORITY_QUEUE);
-
-        return new HeapSnapshotResources(
-                metaInfoSnapshots, cowStateStableSnapshots, stateNamesToId);
+                keyGroupCompressionDecorator,
+                keyGroupRange,
+                getKeySerializer(),
+                totalKeyGroups);
     }
 
     @Override
     public SnapshotResultSupplier<KeyedStateHandle> asyncSnapshot(
-            HeapSnapshotResources syncPartResource,
+            HeapSnapshotResources<K> syncPartResource,
             long checkpointId,
             long timestamp,
             @Nonnull CheckpointStreamFactory streamFactory,
@@ -142,7 +110,7 @@ class HeapSnapshotStrategy<K>
                         // TODO: this code assumes that writing a serializer is threadsafe, we
                         // should support to
                         // get a serialized form already at state registration time in the future
-                        getKeySerializer(),
+                        syncPartResource.getKeySerializer(),
                         metaInfoSnapshots,
                         !Objects.equals(
                                 UncompressedStreamCompressionDecorator.INSTANCE,
@@ -214,65 +182,7 @@ class HeapSnapshotStrategy<K>
         };
     }
 
-    private void processSnapshotMetaInfoForAllStates(
-            List<StateMetaInfoSnapshot> metaInfoSnapshots,
-            Map<StateUID, StateSnapshot> cowStateStableSnapshots,
-            Map<StateUID, Integer> stateNamesToId,
-            Map<String, ? extends StateSnapshotRestore> registeredStates,
-            StateMetaInfoSnapshot.BackendStateType stateType) {
-
-        for (Map.Entry<String, ? extends StateSnapshotRestore> kvState :
-                registeredStates.entrySet()) {
-            final StateUID stateUid = StateUID.of(kvState.getKey(), stateType);
-            stateNamesToId.put(stateUid, stateNamesToId.size());
-            StateSnapshotRestore state = kvState.getValue();
-            if (null != state) {
-                final StateSnapshot stateSnapshot = state.stateSnapshot();
-                metaInfoSnapshots.add(stateSnapshot.getMetaInfoSnapshot());
-                cowStateStableSnapshots.put(stateUid, stateSnapshot);
-            }
-        }
-    }
-
-    private boolean hasRegisteredState() {
-        return !(registeredKVStates.isEmpty() && registeredPQStates.isEmpty());
-    }
-
     public TypeSerializer<K> getKeySerializer() {
         return keySerializerProvider.currentSchemaSerializer();
     }
-
-    static class HeapSnapshotResources implements SnapshotResources {
-        private final List<StateMetaInfoSnapshot> metaInfoSnapshots;
-        private final Map<StateUID, StateSnapshot> cowStateStableSnapshots;
-        private final Map<StateUID, Integer> stateNamesToId;
-
-        HeapSnapshotResources(
-                @Nonnull List<StateMetaInfoSnapshot> metaInfoSnapshots,
-                @Nonnull Map<StateUID, StateSnapshot> cowStateStableSnapshots,
-                @Nonnull Map<StateUID, Integer> stateNamesToId) {
-            this.metaInfoSnapshots = metaInfoSnapshots;
-            this.cowStateStableSnapshots = cowStateStableSnapshots;
-            this.stateNamesToId = stateNamesToId;
-        }
-
-        @Override
-        public void release() {
-            for (StateSnapshot stateSnapshot : cowStateStableSnapshots.values()) {
-                stateSnapshot.release();
-            }
-        }
-
-        public List<StateMetaInfoSnapshot> getMetaInfoSnapshots() {
-            return metaInfoSnapshots;
-        }
-
-        public Map<StateUID, StateSnapshot> getCowStateStableSnapshots() {
-            return cowStateStableSnapshots;
-        }
-
-        public Map<StateUID, Integer> getStateNamesToId() {
-            return stateNamesToId;
-        }
-    }
 }
diff --git a/flink-tests/pom.xml b/flink-tests/pom.xml
index 7031344..fb30a1d 100644
--- a/flink-tests/pom.xml
+++ b/flink-tests/pom.xml
@@ -231,6 +231,15 @@ under the License.
 			<scope>test</scope>
 		</dependency>
 
+
+		<dependency>
+			<groupId>org.apache.flink</groupId>
+			<artifactId>flink-statebackend-rocksdb_${scala.binary.version}</artifactId>
+			<version>${project.version}</version>
+			<type>test-jar</type>
+			<scope>test</scope>
+		</dependency>
+
 		<dependency>
 			<groupId>com.github.oshi</groupId>
 			<artifactId>oshi-core</artifactId>
diff --git a/flink-tests/src/test/java/org/apache/flink/test/state/HeapSavepointStateBackendSwitchTest.java b/flink-tests/src/test/java/org/apache/flink/test/state/HeapSavepointStateBackendSwitchTest.java
index d5708a6..5138e03 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/state/HeapSavepointStateBackendSwitchTest.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/state/HeapSavepointStateBackendSwitchTest.java
@@ -23,8 +23,8 @@ import org.apache.flink.test.state.BackendSwitchSpecs.BackendSwitchSpec;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import java.util.Arrays;
 import java.util.Collection;
-import java.util.Collections;
 
 /** Tests for switching a HEAP state backend to a different one. */
 @RunWith(Parameterized.class)
@@ -35,6 +35,6 @@ public class HeapSavepointStateBackendSwitchTest extends SavepointStateBackendSw
 
     @Parameterized.Parameters
     public static Collection<BackendSwitchSpec> targetBackends() {
-        return Collections.singletonList(BackendSwitchSpecs.HEAP);
+        return Arrays.asList(BackendSwitchSpecs.HEAP, BackendSwitchSpecs.ROCKS);
     }
 }